""" | |
Generate nsql and questions. | |
""" | |
from typing import Dict, List, Union, Tuple | |
import openai | |
import time | |
from generation.prompt import PromptBuilder | |
class Generator(object):
    """
    Codex generation wrapper.

    Builds prompts (via ``PromptBuilder``) and queries the OpenAI Completion
    API, rotating through a pool of API keys and retrying on transient API
    failures.
    """

    def __init__(self, args, keys=None):
        """
        Args:
            args: parsed runtime arguments (engine, sampling settings, ...);
                may be falsy when only the prompt utilities are needed.
            keys: list of OpenAI API keys to rotate through; required only
                when the API is actually called.
        """
        self.args = args
        self.__keys = keys
        self.current_key_id = 0
        # if the args provided, will initialize with the prompt builder for full usage
        self.prompt_builder = PromptBuilder(args) if args else None

    def prompt_row_truncate(
            self,
            prompt: str,
            num_rows_to_remain: int,
            table_end_token: str = '*/',
    ):
        """
        Fit prompt into max token limits by row truncation.

        Table rows are assumed to be lines of the form ``<row_id>\\t...``
        located before the last occurrence of ``table_end_token``; rows whose
        id exceeds ``num_rows_to_remain`` are dropped.

        Args:
            prompt: full prompt text containing a linearized table.
            num_rows_to_remain: largest row id to keep.
            table_end_token: token marking the end of the table block.

        Returns:
            The truncated prompt. If no row id is within the budget, the
            prompt is returned unchanged.
        """
        table_end_pos = prompt.rfind(table_end_token)
        assert table_end_pos != -1
        prompt_part1, prompt_part2 = prompt[:table_end_pos], prompt[table_end_pos:]
        # Scan bottom-up: the first in-budget row seen from the end marks
        # where the kept suffix of the table begins.
        prompt_part1_lines = prompt_part1.split('\n')[::-1]
        trunc_line_index = None
        for idx, line in enumerate(prompt_part1_lines):
            if '\t' not in line:
                continue
            row_id = int(line.split('\t')[0])
            if row_id <= num_rows_to_remain:
                trunc_line_index = idx
                break
        if trunc_line_index is None:
            # No row fits the budget; previously the [None:] slice silently
            # rebuilt the prompt with a stray extra newline — return as-is.
            return prompt
        new_prompt_part1 = '\n'.join(prompt_part1_lines[trunc_line_index:][::-1])
        prompt = new_prompt_part1 + '\n' + prompt_part2
        return prompt

    def build_few_shot_prompt_from_file(
            self,
            file_path: str,
            n_shots: int
    ):
        """
        Build few-shot prompt for generation from file.

        Shots in the file are separated by a blank line followed by another
        blank line (i.e. two consecutive ``'\\n'`` lines).

        Args:
            file_path: path to the prompt file.
            n_shots: number of leading shots to keep.

        Returns:
            The concatenated few-shot prompt string.
        """
        with open(file_path, 'r') as f:
            lines = f.readlines()
        few_shot_prompt_list = []
        one_shot_prompt = ''
        last_line = None
        for line in lines:
            # Two consecutive empty lines delimit one shot.
            if line == '\n' and last_line == '\n':
                few_shot_prompt_list.append(one_shot_prompt)
                one_shot_prompt = ''
            else:
                one_shot_prompt += line
            last_line = line
        few_shot_prompt_list.append(one_shot_prompt)
        few_shot_prompt_list = few_shot_prompt_list[:n_shots]
        few_shot_prompt_list[-1] = few_shot_prompt_list[
            -1].strip()  # It is essential for prompting to remove extra '\n'
        few_shot_prompt = '\n'.join(few_shot_prompt_list)
        return few_shot_prompt

    def build_generate_prompt(
            self,
            data_item: Dict,
            generate_type: Tuple
    ):
        """
        Build the generate prompt by delegating to the prompt builder.

        Args:
            data_item: keyword arguments forwarded to the prompt builder.
            generate_type: tuple identifying the generation phase.
        """
        return self.prompt_builder.build_generate_prompt(
            **data_item,
            generate_type=generate_type
        )

    def generate_one_pass(
            self,
            prompts: List[Tuple],
            verbose: bool = False
    ):
        """
        Generate one pass with codex according to the generation phase.

        Args:
            prompts: list of ``(eid, prompt_text)`` tuples.
            verbose: if True, print prompts and generations.

        Returns:
            Dict mapping each eid to a list of ``(text, total_logprob)``
            pairs, one per sampled completion.
        """
        # Each prompt yields args.sampling_n choices; map flat choice index
        # back to its originating example id.
        result_idx_to_eid = []
        for p in prompts:
            result_idx_to_eid.extend([p[0]] * self.args.sampling_n)
        prompts = [p[1] for p in prompts]

        start_time = time.time()
        result = self._call_codex_api(
            engine=self.args.engine,
            prompt=prompts,
            max_tokens=self.args.max_generation_tokens,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            n=self.args.sampling_n,
            stop=self.args.stop_tokens
        )
        print(f'Openai api one inference time: {time.time() - start_time}')

        if verbose:
            print('\n', '*' * 20, 'Codex API Call', '*' * 20)
            for prompt in prompts:
                print(prompt)
            print('\n')
            print('- - - - - - - - - - ->>')

        # parse api results
        response_dict = dict()
        for idx, g in enumerate(result['choices']):
            text = ''  # keep defined so the error path below can print it
            try:
                text = g['text']
                logprob = sum(g['logprobs']['token_logprobs'])
                eid = result_idx_to_eid[idx]
                eid_pairs = response_dict.get(eid, None)
                if eid_pairs is None:
                    eid_pairs = []
                    response_dict[eid] = eid_pairs
                eid_pairs.append((text, logprob))

                if verbose:
                    print(text)
            except ValueError as e:
                # Skip malformed choices but keep processing the rest.
                if verbose:
                    print('----------- Error Msg--------')
                    print(e)
                    print(text)
                    print('-----------------------------')

        return response_dict

    def _call_codex_api(
            self,
            engine: str,
            prompt: Union[str, List],
            max_tokens,
            temperature: float,
            top_p: float,
            n: int,
            stop: List[str]
    ):
        """
        Call the OpenAI Completion API, rotating keys and retrying forever
        on transient errors.

        Raises:
            ValueError: if no API keys were provided at construction time
                (previously this silently retried forever because the broad
                ``except`` swallowed the resulting error).
        """
        if not self.__keys:
            raise ValueError('No OpenAI API keys provided to Generator.')
        start_time = time.time()
        result = None
        while result is None:
            try:
                # NOTE: fixed from `self.keys`, which never existed —
                # __init__ stores the name-mangled `self.__keys`.
                key = self.__keys[self.current_key_id]
                self.current_key_id = (self.current_key_id + 1) % len(self.__keys)
                result = openai.Completion.create(
                    engine=engine,
                    prompt=prompt,
                    api_key=key,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stop=stop,
                    logprobs=1
                )
                print('Openai api inference time:', time.time() - start_time)
                return result
            except Exception as e:
                # Best-effort retry loop: log and try the next key.
                print(e, 'Retry.')
                time.sleep(5)