File size: 9,207 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import os
import random

from generation.prompt import OpenAIQAPromptBuilder
from generation.generator import Generator
from retrieval.retriever import OpenAIQARetriever
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem

# Number of prompts sent to the OpenAI API in parallel.
num_parallel_prompts = 10
# Number of retrieved few-shot QA demonstrations prepended to each prompt.
num_qa_shots = 8
infinite_rows_len = 50  # Tables with more rows than this are handled chunk by chunk (see OpenAIQAModel.qa).
max_tokens = 1024  # Completion token budget per API call.
# Repository root, resolved relative to this file's location.
ROOT_DIR = os.path.join(os.path.dirname(__file__), "../../")


class OpenAIQAModel(object):
    """Table QA via few-shot prompting of an OpenAI Codex completion model.

    Supports two question types (``qa_type``):
      * ``"map"`` — answer the question for every row, yielding a new column;
      * ``"ans"`` — answer the question once over the whole (merged) table.
    Few-shot demonstrations are retrieved from a QA pool and prepended to
    each prompt.
    """

    def __init__(self, args, keys=None):
        """
        Args:
            args: namespace providing ``qa_retrieve_pool_file`` — path of the
                few-shot retrieval pool, relative to ROOT_DIR.
            keys: optional list of OpenAI API keys; deterministically shuffled
                so concurrent users don't all start from the same key.
        """
        super().__init__()

        # Prepare keys. Guard the shuffle: random.shuffle(None) would raise
        # TypeError when no keys are supplied (keys defaults to None).
        self.key_current_id = 0
        self.keys = keys
        random.seed(42)
        if self.keys:
            random.shuffle(self.keys)

        retrieve_pool = OpenAIQARetrievePool(
            data_path=os.path.join(ROOT_DIR, args.qa_retrieve_pool_file)
        )
        self.retriever = OpenAIQARetriever(retrieve_pool)
        self.generator = Generator(args=None, keys=self.keys)  # Just to use its call api function

        self.prompting_method = 'new_db'
        self.answer_split_token: str = ';'
        self.db_mapping_token = "\t"

    def call_openai_api_completion(self, prompt):
        """Send one completion request to Codex and return the raw response."""
        completion = self.generator._call_codex_api(engine="code-davinci-002",
                                                    prompt=prompt,
                                                    max_tokens=max_tokens,
                                                    temperature=0,  # greedy decoding for determinism
                                                    top_p=1,
                                                    n=1,
                                                    stop=["\n\n"])
        return completion

    def call_openai_for_completion_text(self, prompt, openai_usage_type="completion"):
        """Return the completion text for ``prompt``.

        Raises:
            ValueError: if ``openai_usage_type`` is not ``"completion"``.
        """
        if openai_usage_type == "completion":
            completion = self.call_openai_api_completion(prompt)
            return completion.choices[0].text
        else:
            raise ValueError("The model usage type '{}' doesn't exists!".format(openai_usage_type))

    @staticmethod
    def merge_tables(tables, by='row_id'):
        """Join row-aligned sub-tables on the ``by`` column.

        All tables must have the same number of rows. Columns other than
        ``by`` are concatenated in table order; duplicate column names get a
        ``_1``/``_2``/... postfix.

        Returns:
            dict with keys ``"header"`` (list of column names) and ``"rows"``
            (list of rows, each starting with the ``by`` value).
        """
        assert len(set([len(_table['rows']) for _table in tables])) == 1, "Tables must have the same rows!"
        merged_header = [by]
        by_idx = tables[0]['header'].index(by)
        merged_rows = [[_row[by_idx]] for _row in tables[0]['rows']]

        for _table in tables:
            header, rows = _table['header'], _table['rows']
            for col_idx, col in enumerate(header):
                if col == by:
                    continue
                if col in merged_header:
                    # Duplicate column: probe _1, _2, ... until the name is
                    # free. (merged_header.count(col) always returned 1, so
                    # merging three-plus tables sharing a column used to emit
                    # two identical "<col>_1" headers.)
                    suffix = 1
                    while "{}_{}".format(col, suffix) in merged_header:
                        suffix += 1
                    col = "{}_{}".format(col, suffix)
                merged_header.append(col)
                for i, row in enumerate(rows):
                    merged_rows[i].append(row[col_idx])
        return {"header": merged_header, "rows": merged_rows}

    def wrap_with_prompt_for_table_qa(self,
                                      question,
                                      sub_table,
                                      table_title=None,
                                      answer_split_token=None,
                                      qa_type="ans",
                                      prompting_method="new_db",
                                      db_mapping_token="😅",  # NOTE(review): was mojibake "πŸ˜…"; repaired to the intended sentinel emoji — confirm against OpenAIQAPromptBuilder
                                      verbose=True):
        """Build the Codex prompt: few-shot examples, the table, the question.

        Args:
            question: natural-language question.
            sub_table: table dict (``header``/``rows``) to ask about.
            table_title: optional title passed to the prompt builder.
            answer_split_token: delimiter for "map" answers (required then).
            qa_type: ``"map"`` or ``"ans"``.
            prompting_method: ``"basic"`` or ``"new_db"``.
            db_mapping_token: row-to-answer separator for "new_db" prompting.
            verbose: forwarded to the few-shot prompt builder.

        Returns:
            The full prompt string.
        """
        prompt = "Question Answering Over Database:\n\n"
        if qa_type in ['map', 'ans'] and num_qa_shots > 0:
            query_item = QAItem(qa_question=question, table=sub_table, title=table_title)
            retrieved_items = self.retriever.retrieve(item=query_item, num_shots=num_qa_shots, qa_type=qa_type)
            few_shot_prompt_list = []
            for item in retrieved_items:
                one_shot_prompt = OpenAIQAPromptBuilder.build_one_shot_prompt(
                    item=item,
                    answer_split_token=answer_split_token,
                    verbose=verbose,
                    prompting_method=prompting_method,
                    db_mapping_token=db_mapping_token
                )
                few_shot_prompt_list.append(one_shot_prompt)
            few_shot_prompt = '\n'.join(few_shot_prompt_list[:num_qa_shots])
            # NOTE(review): this replaces (not appends to) the header above,
            # so the few-shot examples start the prompt — preserved as-is.
            prompt = few_shot_prompt

        prompt += "\nGive a database as shown below:\n{}\n\n".format(
            OpenAIQAPromptBuilder.table2codex_prompt(sub_table, table_title)
        )

        if qa_type == "map":
            prompt += "Q: Answer question \"{}\" row by row.".format(question)
            assert answer_split_token is not None
            if prompting_method == "basic":
                prompt += " The answer should be a list split by '{}' and have {} items in total.".format(
                    answer_split_token, len(sub_table['rows']))

        elif qa_type == "ans":
            prompt += "Q: Answer question \"{}\" for the table.".format(question)
            prompt += " "
        else:
            raise ValueError("The QA type is not supported!")

        prompt += "\n"
        # "new_db" map prompts get no "A:" cue; the model is expected to emit
        # a mini-database instead of a flat answer list.
        if qa_type == "map":
            if prompting_method == "basic":
                prompt += "A:"
        elif qa_type == "ans":
            prompt += "A:"

        return prompt

    def qa(self, question, sub_tables, qa_type: str, verbose: bool = True, **args):
        """Answer ``question`` over ``sub_tables`` with the QA model.

        Args:
            question: natural-language question.
            sub_tables: list of row-aligned sub-tables, merged on ``row_id``.
            qa_type: ``"map"`` (per-row answers → new column) or ``"ans"``
                (single answer for the whole table).
            verbose: print prompts and answers for debugging.
            **args: ``table_title`` always; ``new_col_name_s`` for ``"map"``.

        Returns:
            ``"map"``: table dict with header ``['row_id'] + new_col_name_s``.
            ``"ans"``: single-element list containing the answer string.

        Raises:
            ValueError: unknown ``qa_type`` or prompting method.
        """
        # If it is not a problem API can handle, answer it with a QA model.
        merged_table = OpenAIQAModel.merge_tables(sub_tables)
        if verbose:
            print("Make Question {} on {}".format(question, merged_table))
        if qa_type == "map":
            # Map: col(s) -question> one col

            def do_map(_table):
                # One Codex call answering the question for every row of _table.
                _prompt = self.wrap_with_prompt_for_table_qa(question,
                                                             _table,
                                                             args['table_title'],
                                                             self.answer_split_token,
                                                             qa_type,
                                                             prompting_method=self.prompting_method,
                                                             db_mapping_token=self.db_mapping_token,
                                                             verbose=verbose)
                completion_str = self.call_openai_for_completion_text(_prompt).lower().strip(' []')

                if verbose:
                    print(f'QA map@ input:\n{_prompt}')
                    print(f'QA map@ output:\n{completion_str}')

                if self.prompting_method == "basic":
                    # Completion is a single delimiter-separated answer list.
                    answers = [_answer.strip(" '").lower() for _answer in
                               completion_str.split(self.answer_split_token)]
                elif self.prompting_method == "new_db":
                    # Completion is a mini-database: skip the two header lines
                    # and the trailing line, keep text after the mapping token.
                    answers = [line.split(self.db_mapping_token)[-1] for line in completion_str.split("\n")[2:-1]]
                else:
                    raise ValueError("No such prompting methods: '{}'! ".format(self.prompting_method))
                return answers

            # Handle oversized tables chunk by chunk, at most
            # infinite_rows_len rows per API call.
            answers = []
            all_rows = merged_table['rows']
            run_times = -(-len(all_rows) // infinite_rows_len)  # exact ceil division
            for run_idx in range(run_times):
                # Python slices clamp past-the-end bounds, so the final
                # (possibly short) chunk needs no special case.
                _table = {
                    "header": merged_table['header'],
                    "rows": all_rows[run_idx * infinite_rows_len:(run_idx + 1) * infinite_rows_len]
                }
                answers.extend(do_map(_table))
            if verbose:
                print("The map@ openai answers are {}".format(answers))
            # Add row_id in addition for finding to corresponding rows.
            # NOTE(review): zip truncates when the model returns fewer answers
            # than rows — trailing rows are silently dropped. Confirm intended.
            return {"header": ['row_id'] + args['new_col_name_s'],
                    "rows": [[row[0], answer] for row, answer in zip(merged_table['rows'], answers)]}
        elif qa_type == "ans":
            # Ans: col(s) -question> answer
            prompt = self.wrap_with_prompt_for_table_qa(question,
                                                        merged_table,
                                                        args['table_title'],
                                                        prompting_method=self.prompting_method,
                                                        verbose=verbose)
            answers = [self.call_openai_for_completion_text(prompt).lower().strip(' []')]

            if verbose:
                print(f'QA ans@ input:\n{prompt}')
                print(f'QA ans@ output:\n{answers}')

            return answers
        else:
            raise ValueError("Please choose from map and ans in the qa usage!!")