""" Generate nsql and questions. """ from typing import Dict, List, Union, Tuple import openai import time from generation.prompt import PromptBuilder class Generator(object): """ Codex generation wrapper. """ def __init__(self, args, keys=None): self.args = args self.__keys = keys self.current_key_id = 0 # if the args provided, will initialize with the prompt builder for full usage self.prompt_builder = PromptBuilder(args) if args else None def prompt_row_truncate( self, prompt: str, num_rows_to_remain: int, table_end_token: str = '*/', ): """ Fit prompt into max token limits by row truncation. """ table_end_pos = prompt.rfind(table_end_token) assert table_end_pos != -1 prompt_part1, prompt_part2 = prompt[:table_end_pos], prompt[table_end_pos:] prompt_part1_lines = prompt_part1.split('\n')[::-1] trunc_line_index = None for idx, line in enumerate(prompt_part1_lines): if '\t' not in line: continue row_id = int(line.split('\t')[0]) if row_id <= num_rows_to_remain: trunc_line_index = idx break new_prompt_part1 = '\n'.join(prompt_part1_lines[trunc_line_index:][::-1]) prompt = new_prompt_part1 + '\n' + prompt_part2 return prompt def build_few_shot_prompt_from_file( self, file_path: str, n_shots: int ): """ Build few-shot prompt for generation from file. """ with open(file_path, 'r') as f: lines = f.readlines() few_shot_prompt_list = [] one_shot_prompt = '' last_line = None for line in lines: if line == '\n' and last_line == '\n': few_shot_prompt_list.append(one_shot_prompt) one_shot_prompt = '' else: one_shot_prompt += line last_line = line few_shot_prompt_list.append(one_shot_prompt) few_shot_prompt_list = few_shot_prompt_list[:n_shots] few_shot_prompt_list[-1] = few_shot_prompt_list[ -1].strip() # It is essential for prompting to remove extra '\n' few_shot_prompt = '\n'.join(few_shot_prompt_list) return few_shot_prompt def build_generate_prompt( self, data_item: Dict, generate_type: Tuple ): """ Build the generate prompt """ return self.prompt_builder.build_generate_prompt( **data_item, generate_type=generate_type ) def generate_one_pass( self, prompts: List[Tuple], verbose: bool = False ): """ Generate one pass with codex according to the generation phase. """ result_idx_to_eid = [] for p in prompts: result_idx_to_eid.extend([p[0]] * self.args.sampling_n) prompts = [p[1] for p in prompts] start_time = time.time() result = self._call_codex_api( engine=self.args.engine, prompt=prompts, max_tokens=self.args.max_generation_tokens, temperature=self.args.temperature, top_p=self.args.top_p, n=self.args.sampling_n, stop=self.args.stop_tokens ) print(f'Openai api one inference time: {time.time() - start_time}') if verbose: print('\n', '*' * 20, 'Codex API Call', '*' * 20) for prompt in prompts: print(prompt) print('\n') print('- - - - - - - - - - ->>') # parse api results response_dict = dict() for idx, g in enumerate(result['choices']): try: text = g['text'] logprob = sum(g['logprobs']['token_logprobs']) eid = result_idx_to_eid[idx] eid_pairs = response_dict.get(eid, None) if eid_pairs is None: eid_pairs = [] response_dict[eid] = eid_pairs eid_pairs.append((text, logprob)) if verbose: print(text) except ValueError as e: if verbose: print('----------- Error Msg--------') print(e) print(text) print('-----------------------------') pass return response_dict def _call_codex_api( self, engine: str, prompt: Union[str, List], max_tokens, temperature: float, top_p: float, n: int, stop: List[str] ): start_time = time.time() result = None while result is None: try: key = self.keys[self.current_key_id] self.current_key_id = (self.current_key_id + 1) % len(self.keys) result = openai.Completion.create( engine=engine, prompt=prompt, api_key=key, max_tokens=max_tokens, temperature=temperature, top_p=top_p, n=n, stop=stop, logprobs=1 ) print('Openai api inference time:', time.time() - start_time) return result except Exception as e: print(e, 'Retry.') time.sleep(5)