Spaces:
Runtime error
Runtime error
File size: 5,683 Bytes
f6f97d8 b15814b f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
Generate nsql and questions.
"""
from typing import Dict, List, Union, Tuple
import openai
import time
from generation.prompt import PromptBuilder
class Generator(object):
    """
    Codex generation wrapper.

    Builds prompts (through ``PromptBuilder`` when ``args`` is given) and
    samples completions from ``openai.Completion.create``, rotating
    round-robin over the supplied API keys.
    """

    def __init__(self, args, keys=None):
        """
        Args:
            args: run configuration. ``generate_one_pass`` reads its
                ``engine``, ``max_generation_tokens``, ``temperature``,
                ``top_p``, ``sampling_n`` and ``stop_tokens`` attributes.
                When truthy, a ``PromptBuilder`` is created from it.
            keys: sequence of OpenAI API keys, used round-robin by
                ``_call_codex_api``.
        """
        self.args = args
        # Public attribute on purpose: _call_codex_api reads ``self.keys``,
        # and a double-underscore name would be mangled to _Generator__keys.
        self.keys = keys
        self.current_key_id = 0
        # if the args provided, will initialize with the prompt builder for full usage
        self.prompt_builder = PromptBuilder(args) if args else None

    def prompt_row_truncate(
            self,
            prompt: str,
            num_rows_to_remain: int,
            table_end_token: str = '*/',
    ):
        """
        Fit prompt into max token limits by row truncation.

        Only the text before the LAST ``table_end_token`` is inspected.
        Scanning upwards from that token, the first tab-separated line whose
        leading integer field is ``<= num_rows_to_remain`` becomes the new
        last line before the token; everything between it and the token is
        dropped.

        Args:
            prompt: full prompt text containing ``table_end_token``.
            num_rows_to_remain: largest row id to keep.
            table_end_token: marker that closes the serialized table.

        Returns:
            The truncated prompt, or ``prompt`` unchanged when no qualifying
            row is found.

        Raises:
            AssertionError: if ``table_end_token`` does not occur in ``prompt``.
            ValueError: if the scan reaches a tab-separated line whose first
                field is not an integer (e.g. a header row) before finding a
                qualifying row.
        """
        table_end_pos = prompt.rfind(table_end_token)
        assert table_end_pos != -1
        prompt_part1, prompt_part2 = prompt[:table_end_pos], prompt[table_end_pos:]
        # Walk the lines bottom-up so the first match is the last row to keep.
        prompt_part1_lines = prompt_part1.split('\n')[::-1]
        trunc_line_index = None
        for idx, line in enumerate(prompt_part1_lines):
            if '\t' not in line:
                continue
            row_id = int(line.split('\t')[0])
            if row_id <= num_rows_to_remain:
                trunc_line_index = idx
                break
        if trunc_line_index is None:
            # Nothing to cut. Re-joining below would add an extra '\n' before
            # the token whenever the text preceding it ends with a newline.
            return prompt
        new_prompt_part1 = '\n'.join(prompt_part1_lines[trunc_line_index:][::-1])
        return new_prompt_part1 + '\n' + prompt_part2

    def build_few_shot_prompt_from_file(
            self,
            file_path: str,
            n_shots: int
    ):
        """
        Build few-shot prompt for generation from file.

        Examples in the file are separated by two consecutive blank lines.
        The first ``n_shots`` examples are joined with ``'\\n'``. The last one
        is stripped of surrounding whitespace, because extra trailing
        newlines hurt prompting.

        Args:
            file_path: path to a UTF-8 text file of examples.
            n_shots: number of leading examples to keep.

        Returns:
            The few-shot prompt string; ``''`` when no example is selected.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        few_shot_prompt_list = []
        current_lines = []
        last_line = None
        for line in lines:
            if line == '\n' and last_line == '\n':
                few_shot_prompt_list.append(''.join(current_lines))
                current_lines = []
            else:
                current_lines.append(line)
            last_line = line
        few_shot_prompt_list.append(''.join(current_lines))
        few_shot_prompt_list = few_shot_prompt_list[:n_shots]
        if not few_shot_prompt_list:
            # e.g. n_shots == 0: there is no "last" example to strip.
            return ''
        # It is essential for prompting to remove extra '\n'
        few_shot_prompt_list[-1] = few_shot_prompt_list[-1].strip()
        return '\n'.join(few_shot_prompt_list)

    def build_generate_prompt(
            self,
            data_item: Dict,
            generate_type: Tuple
    ):
        """
        Build the generate prompt.

        Delegates to ``self.prompt_builder.build_generate_prompt``, passing
        the entries of ``data_item`` as keyword arguments. Requires the
        Generator to have been built with truthy ``args``; otherwise
        ``prompt_builder`` is None.
        """
        return self.prompt_builder.build_generate_prompt(
            **data_item,
            generate_type=generate_type
        )

    def generate_one_pass(
            self,
            prompts: List[Tuple],
            verbose: bool = False
    ):
        """
        Generate one pass with codex according to the generation phase.

        Args:
            prompts: tuples whose element 0 is an example id (eid) and whose
                element 1 is the prompt text. All prompts are sent in one
                batched API call.
            verbose: print prompts, completions and per-choice parse errors.

        Returns:
            Dict mapping each eid to a list of ``(text, logprob)`` pairs,
            where ``logprob`` is the sum of the choice's token logprobs.
            Choices that cannot be parsed are skipped.
        """
        # The API returns sampling_n choices per prompt, prompt-major, so
        # choice i belongs to prompt i // sampling_n.
        result_idx_to_eid = []
        for p in prompts:
            result_idx_to_eid.extend([p[0]] * self.args.sampling_n)
        prompt_texts = [p[1] for p in prompts]

        start_time = time.time()
        result = self._call_codex_api(
            engine=self.args.engine,
            prompt=prompt_texts,
            max_tokens=self.args.max_generation_tokens,
            temperature=self.args.temperature,
            top_p=self.args.top_p,
            n=self.args.sampling_n,
            stop=self.args.stop_tokens
        )
        print(f'Openai api one inference time: {time.time() - start_time}')

        if verbose:
            print('\n', '*' * 20, 'Codex API Call', '*' * 20)
            for prompt in prompt_texts:
                print(prompt)
                print('\n')
            print('- - - - - - - - - - ->>')

        # parse api results
        response_dict = dict()
        for idx, g in enumerate(result['choices']):
            text = None  # reset so the error branch never shows a stale value
            try:
                text = g['text']
                logprob = sum(g['logprobs']['token_logprobs'])
                # Batched choices are not guaranteed to arrive in order; use
                # the API-reported index when available.
                eid = result_idx_to_eid[g.get('index', idx)]
            except (KeyError, IndexError, TypeError, ValueError) as e:
                # Best effort: skip a malformed choice, keep the rest.
                if verbose:
                    print('----------- Error Msg--------')
                    print(e)
                    print(text)
                    print('-----------------------------')
                continue
            response_dict.setdefault(eid, []).append((text, logprob))
            if verbose:
                print(text)
        return response_dict

    def _call_codex_api(
            self,
            engine: str,
            prompt: Union[str, List],
            max_tokens,
            temperature: float,
            top_p: float,
            n: int,
            stop: List[str],
            max_retries: Union[int, None] = None,
            retry_interval: float = 5,
    ):
        """
        Call ``openai.Completion.create`` with ``logprobs=1``, retrying on error.

        Each attempt uses the next key from ``self.keys`` (round-robin).

        Args:
            engine, prompt, max_tokens, temperature, top_p, n, stop:
                forwarded to ``openai.Completion.create``.
            max_retries: number of retries after the first failed attempt;
                ``None`` (default) retries forever, as before.
            retry_interval: seconds to sleep between attempts.

        Returns:
            The raw API response.

        Raises:
            ValueError: if no API keys were provided.
            Exception: the last API error, once ``max_retries`` is exhausted.
        """
        if not self.keys:
            # Otherwise the key rotation fails (modulo by zero / None) and
            # the retry loop below would spin forever on that error.
            raise ValueError('No OpenAI API keys were provided to Generator.')
        start_time = time.time()
        failures = 0
        while True:
            key = self.keys[self.current_key_id]
            self.current_key_id = (self.current_key_id + 1) % len(self.keys)
            try:
                result = openai.Completion.create(
                    engine=engine,
                    prompt=prompt,
                    api_key=key,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stop=stop,
                    logprobs=1
                )
            except Exception as e:
                failures += 1
                if max_retries is not None and failures > max_retries:
                    raise
                print(e, 'Retry.')
                time.sleep(retry_interval)
                continue
            print('Openai api inference time:', time.time() - start_time)
            return result
|