File size: 1,989 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
from fengshen.pipelines.multiplechoice import UniMCPipelines
import os
import json
import copy
from tqdm import tqdm

def load_data(data_path):
    """Load a JSON Lines file into a list of parsed samples.

    Args:
        data_path: Path to a UTF-8 text file holding one JSON document per
            line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf8') as f:
        # Blank or whitespace-only lines (e.g. a stray empty line at the end
        # of the file) are not JSON documents; drop them up front so that
        # json.loads does not fail on them and the progress-bar total matches
        # the number of samples actually parsed.
        lines = [line for line in f if line.strip()]
    return [json.loads(line) for line in tqdm(lines)]


def comp_acc(pred_data, test_data):
    """Return the fraction of predictions whose 'label' matches the reference.

    Items are compared position by position: ``pred_data[i]['label']`` against
    ``test_data[i]['label']``.

    Args:
        pred_data: Sequence of mappings, each with a 'label' key (predictions).
        test_data: Sequence of mappings, each with a 'label' key (references).
            Must be at least as long as ``pred_data``; extra trailing items
            are ignored.

    Returns:
        Accuracy as a float in [0, 1]. Returns 0.0 when ``pred_data`` is empty
        instead of dividing by zero.

    Raises:
        IndexError: If ``test_data`` has fewer items than ``pred_data``.
        KeyError: If a compared item has no 'label' key.
    """
    total = len(pred_data)
    if total == 0:
        # Nothing to score; avoid ZeroDivisionError.
        return 0.0
    if len(test_data) < total:
        # zip() would silently truncate; keep the failure explicit, using the
        # same exception type that positional indexing would have raised.
        raise IndexError(
            'test_data has fewer items (%d) than pred_data (%d)'
            % (len(test_data), total)
        )
    correct = sum(
        pred['label'] == gold['label']
        for pred, gold in zip(pred_data, test_data)
    )
    return correct / total


def main():
    """Command-line entry point: optional training, dev scoring, test export.

    Steps, in order:
      1. Parse the data-location flags declared here plus whatever options
         ``UniMCPipelines.piplines_args`` adds to the parser.
      2. Load the train, dev and test splits from ``--data_dir``.
      3. Build the pipeline and, when ``args.train`` is truthy, train it
         (``train`` is not declared in this function; presumably it is
         registered by ``piplines_args`` -- confirm).
      4. Predict on the dev split, print the first 20 predictions and the
         accuracy against the labels as originally loaded.
      5. When ``--output_path`` is non-empty, predict on the test split and
         write one JSON object per line to that path.
    """
    parser = argparse.ArgumentParser("TASK NAME")
    string_options = (
        ('--data_dir', './data'),
        ('--train_data', 'train.json'),
        ('--valid_data', 'dev.json'),
        ('--test_data', 'test.json'),
        ('--output_path', ''),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, default=default_value, type=str)

    parser = UniMCPipelines.piplines_args(parser)
    args = parser.parse_args()

    # Splits are read in train -> dev -> test order.
    train_data, dev_data, test_data = (
        load_data(os.path.join(args.data_dir, file_name))
        for file_name in (args.train_data, args.valid_data, args.test_data)
    )

    # Snapshot of the dev samples taken before they are handed to the
    # pipeline; accuracy is scored against this copy.
    # NOTE(review): presumably train/predict may modify their input -- confirm.
    dev_reference = copy.deepcopy(dev_data)

    pipeline = UniMCPipelines(args, args.pretrained_model_path)

    print(args.data_dir)

    if args.train:
        pipeline.train(train_data, dev_data)

    dev_predictions = pipeline.predict(dev_data)
    for prediction in dev_predictions[:20]:
        print(prediction)

    print('acc:', comp_acc(dev_predictions, dev_reference))

    if args.output_path == '':
        # No export requested.
        return

    test_predictions = pipeline.predict(test_data)
    with open(args.output_path, 'w', encoding='utf8') as out_file:
        for prediction in test_predictions:
            out_file.write(json.dumps(prediction, ensure_ascii=False) + '\n')


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()