File size: 1,765 Bytes
7de3018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import json
import os
import pandas as pd

# Base path for locating data files: two directory levels above the folder
# that contains this module (kept un-normalized, with the trailing slash).
_THIS_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.join(_THIS_DIR, "../../")


class Question_Image_Match_Classifier(object):
    """Look up precomputed question/image match predictions for MMQA.

    Results are from a T5-3b model finetuned on the train set of MMQA. No model
    is run here: the predictions are read from CSV files and served from
    in-memory dictionaries.

    Attributes:
        whether_retrieve_image: dict mapping each ``id`` in ``qc_mmqa_dev.csv``
            to ``True`` when its prediction is ``"['yes']"``.
        qi_pairs_should_retrieve: dict mapping each lower-cased ``question``
            in ``qimc_mmqa_dev.csv`` to ``True`` when its prediction is
            ``"['yes']"``.
        caption_info: parsed content of ``mmqa_captions.json``; indexed in
            ``judge_match`` by the image file name up to its first dot.
    """

    # Exact textual form of a positive prediction in the CSV files.
    _YES_PREDICTION = "['yes']"

    def __init__(self, data_dir=None):
        """Load the prediction tables and the image captions.

        Args:
            data_dir: Directory holding ``qc_mmqa_dev.csv``,
                ``qimc_mmqa_dev.csv`` and ``mmqa_captions.json``. Defaults to
                ``<ROOT_DIR>/utils/mmqa``.
        """
        if data_dir is None:
            data_dir = self._default_data_dir()

        self.whether_retrieve_image = None
        self.qi_pairs_should_retrieve = None
        self.load_retrieve_info(data_dir)

        # JSON text is UTF-8; be explicit so captions with non-ASCII characters
        # do not depend on the platform's default encoding.
        captions_path = os.path.join(data_dir, "mmqa_captions.json")
        with open(captions_path, "r", encoding="utf-8") as f:
            self.caption_info = json.load(f)

    @staticmethod
    def _default_data_dir():
        """Return the default data location, ``<ROOT_DIR>/utils/mmqa``."""
        # Resolved lazily so ROOT_DIR is only needed when no directory is given.
        return os.path.join(ROOT_DIR, "utils", "mmqa")

    @classmethod
    def _read_yes_flags(cls, csv_path, key_column, lowercase_keys=False):
        """Map every ``key_column`` value of a CSV file to a yes/no boolean.

        Args:
            csv_path: Path of a CSV file with ``key_column`` and ``prediction``
                columns.
            key_column: Name of the column supplying the dictionary keys.
            lowercase_keys: When true, keys are lower-cased before insertion.

        Returns:
            dict of key -> bool; the value is ``True`` only when the row's
            prediction equals ``"['yes']"``. For duplicated keys the last row
            wins.
        """
        df = pd.read_csv(csv_path)
        keys = df[key_column]
        if lowercase_keys:
            keys = (key.lower() for key in keys)
        return {
            key: bool(prediction == cls._YES_PREDICTION)
            for key, prediction in zip(keys, df['prediction'])
        }

    def load_retrieve_info(self, data_dir=None):
        """(Re)load both prediction tables from ``data_dir``.

        Populates ``whether_retrieve_image`` from ``qc_mmqa_dev.csv`` and
        ``qi_pairs_should_retrieve`` from ``qimc_mmqa_dev.csv``.

        Args:
            data_dir: Directory containing the two CSV files. Defaults to
                ``<ROOT_DIR>/utils/mmqa``.
        """
        if data_dir is None:
            data_dir = self._default_data_dir()

        self.whether_retrieve_image = self._read_yes_flags(
            os.path.join(data_dir, "qc_mmqa_dev.csv"), 'id'
        )
        self.qi_pairs_should_retrieve = self._read_yes_flags(
            os.path.join(data_dir, "qimc_mmqa_dev.csv"), 'question', lowercase_keys=True
        )

    def judge_match(self, _id, question, pic):
        """Return the stored prediction for a (question, image) pair.

        Args:
            _id: Key into ``whether_retrieve_image``.
            question: Question text; it is lower-cased before the lookup.
            pic: Image path or file name. The caption is looked up by the file
                name up to its first dot.

        Returns:
            ``False`` when ``whether_retrieve_image[_id]`` is false; otherwise
            the boolean stored for this question/caption pair.

        Raises:
            KeyError: If ``_id``, the image key, or the question/caption pair
                is missing from the loaded tables.
        """
        # fixme: hardcode since it is done in pipeline, change that in the future
        if not self.whether_retrieve_image[_id]:
            return False
        image_key = os.path.basename(pic).split(".")[0]
        image_caption = self.caption_info[image_key]
        # The key layout must match the ``question`` column of the CSV exactly,
        # including the space before the newline.
        pair_key = 'qa: {} \n{}'.format(question.lower(), image_caption.lower())
        return self.qi_pairs_should_retrieve[pair_key]