import sys
import datetime
import json
import os
import torch

script_dir = os.path.dirname(os.path.realpath(__file__))

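# Allow imports from the repository root (one level up): datasets, models, eval_utils.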
sys.path.append(os.path.join(script_dir, '..'))

from datasets.vqa_dataset import docVQADataset, docVQATESTDataset, textVQADataset


print(torch.__version__)

import numpy as np

from eval_utils.getargs import parse_args
from eval_utils.vqa_evaluate import *


def get_model(args):
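    """Build the MiniCPM-V variant selected by args.model_name from model_path / ckpt."""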
    if args.model_name == '':
        raise Exception('Model name cannot be an empty string!')
    from models.MiniCPM.minicpmv import MiniCPM_V, MiniCPM_V_2_6
    model_path = args.model_path
    ckpt = args.ckpt
    
    if args.model_name == 'minicpmv':
        model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)
    elif args.model_name == 'minicpmv26':
        model = MiniCPM_V_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
    else:
        raise Exception(f"Unexpected model name {args.model_name}!")
    
    return model


def main(args):
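    """Run the selected VQA benchmarks and dump the accuracies to result.json."""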
    np.random.seed(0)
    max_sample_num = None

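    # Set up distributed evaluation: NCCL process group, one process per GPU,
    # with single-process defaults when WORLD_SIZE/RANK are not set.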
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
    print(f'Init Rank-{torch.distributed.get_rank()}')
    if torch.distributed.is_initialized():
        args.device = torch.device(f"cuda:{torch.cuda.current_device()}")

    model = get_model(args)
    
    result = {}
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

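    # Run each requested benchmark and collect its accuracy in `result`.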
    if args.eval_textVQA or args.eval_all:
        dataset = textVQADataset(args.textVQA_image_dir, args.textVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time,
                           batch_size=args.batchsize, generate_method=args.generate_method,
                           answer_path=args.answer_path)
        result['textVQA'] = acc

    if args.eval_docVQA or args.eval_all:
        dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time,
                           batch_size=args.batchsize, generate_method=args.generate_method,
                           answer_path=args.answer_path)
        result['docVQA'] = acc

    if args.eval_docVQATest or args.eval_all:
        target_dataset = "docVQATest"
        dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, target_dataset, time,
                           batch_size=args.batchsize, generate_method=args.generate_method,
                           answer_path=args.answer_path)
        result['docVQATest'] = acc
    
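    # Synchronize all ranks; only rank 0 writes the aggregated results.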
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return None

    result_path = os.path.join(args.answer_path, args.model_name, 'result.json')
    
    # Only write the result file when at least one benchmark produced a non-zero accuracy.
    if any(v > 0.0 for v in result.values()):
        with open(result_path, "w") as f:
            f.write(json.dumps(result, indent=4))


if __name__ == "__main__":
    args = parse_args()

    main(args)