File size: 1,317 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
Retrieval pool of QA candidates loaded from a JSON file, plus the QAItem record type.
"""
import json
from dataclasses import dataclass
from typing import Dict, List, Optional


class OpenAIQARetrievePool(object):
    """In-memory pool of QA retrieval candidates loaded from a JSON file.

    The file is expected to hold a JSON array of objects, each providing the
    keys ``id``, ``qa_question``, ``qa_column``, ``qa_answer``, ``table`` and
    ``title``. Every object becomes one ``QAItem``. A list-valued
    ``qa_column`` is flattened into a single ``'|'``-joined string.

    The pool supports ``len()``, indexing, and iteration. Each ``for`` loop
    gets its own fresh iterator, so loops may be nested or left early with
    ``break`` without affecting later loops. Calling ``next(pool)`` directly
    still walks the legacy shared cursor stored in ``self.pointer``.
    """

    def __init__(self, data_path):
        """Load every candidate from the JSON file at *data_path*.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
            KeyError: if a record lacks one of the required keys.
        """
        # JSON text is UTF-8 by spec; do not depend on the locale's default encoding.
        with open(data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.data = []
        for d in data:
            qa_column = d['qa_column']
            if isinstance(qa_column, list):
                # Collapse multiple columns into one delimited string.
                qa_column = '|'.join(qa_column)
            self.data.append(QAItem(
                id=d['id'],
                qa_question=d['qa_question'],
                qa_column=qa_column,
                qa_answer=d['qa_answer'],
                table=d['table'],
                title=d['title'],
            ))

        # Cursor used only by direct ``next(pool)`` calls (see __next__).
        self.pointer = 0

    def __iter__(self):
        """Return a fresh iterator over the pooled items.

        Returning ``self`` made all loops share one cursor. Nested loops then
        never terminated, and a loop exited with ``break`` left the next loop
        starting mid-pool.
        """
        return iter(self.data)

    def __next__(self):
        """Return the item at the shared cursor and advance it.

        Kept for callers that call ``next(pool)`` directly. When the cursor
        passes the last item it is reset to 0 and StopIteration is raised.
        """
        pointer = self.pointer
        if pointer < len(self):
            self.pointer += 1
            return self.data[pointer]
        self.pointer = 0
        raise StopIteration

    def __getitem__(self, item):
        """Return the item(s) at index or slice *item*, as list indexing does."""
        return self.data[item]

    def __len__(self):
        """Return the number of pooled items."""
        return len(self.data)


@dataclass
class QAItem(object):
    """One QA retrieval candidate.

    Every field defaults to ``None`` so a partially populated item can be
    constructed; the annotations are therefore ``Optional``.
    """

    # Record identifier. Annotated as int, but the JSON loader passes the raw value through.
    id: Optional[int] = None
    qa_question: Optional[str] = None
    # OpenAIQARetrievePool joins list-valued columns into one '|'-separated string.
    qa_column: Optional[str] = None
    qa_answer: Optional[str] = None
    # NOTE(review): structure not visible here; presumably the source table the QA refers to.
    table: Optional[Dict] = None
    title: Optional[str] = None