|
"""Generate answers with GPT-3.5""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
import time |
|
import concurrent.futures |
|
|
|
import openai |
|
import tqdm |
|
import shortuuid |
|
|
|
# Model name sent to the chat-completion endpoint for every request.
MODEL = "gpt-3.5-turbo"
# Identifier stored in the "model_id" field of each answer record.
MODEL_ID = "gpt-3.5-turbo:20230327"
|
|
|
|
|
def get_answer(question_id: int, question: str, max_tokens: int) -> dict:
    """Ask the chat model one question and package the reply as an answer record.

    The request is attempted up to three times, pausing one second between
    attempts. Any exception raised during an attempt is printed and the next
    attempt is made; this function itself does not propagate those errors.

    Args:
        question_id: Identifier copied unchanged into the returned record.
        question: Text sent as the user message.
        max_tokens: Passed through as the completion's ``max_tokens`` limit.

    Returns:
        A dict with the keys ``answer_id`` (a fresh ``shortuuid``),
        ``question_id``, ``model_id`` (the module-level ``MODEL_ID``) and
        ``text``. ``text`` holds the message content of the first returned
        choice, or the placeholder ``"#ERROR#"`` when every attempt failed.
    """
    max_attempts = 3
    ans = {
        "answer_id": shortuuid.uuid(),
        "question_id": question_id,
        "model_id": MODEL_ID,
    }
    for attempt in range(max_attempts):
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {
                        "role": "user",
                        "content": question,
                    },
                ],
                max_tokens=max_tokens,
            )
            ans["text"] = response["choices"][0]["message"]["content"]
            return ans
        except Exception as e:
            # Deliberately broad: this is a best-effort boundary, so a failed
            # request or an unexpected response shape is reported and retried
            # instead of aborting the whole run.
            print("[ERROR]", e)
            ans["text"] = "#ERROR#"
            # Only wait when another attempt follows; sleeping after the last
            # failure would just delay returning the error placeholder.
            if attempt < max_attempts - 1:
                time.sleep(1)
    return ans
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: path of the questions file, path of the answers
    # file to write, and the completion token limit applied to every answer.
    parser = argparse.ArgumentParser(description="ChatGPT answer generation.")
    # Both paths are mandatory. Without required=True a missing flag would only
    # surface later as an opaque TypeError from os.path.expanduser(None).
    parser.add_argument("-q", "--question", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1024,
        help="maximum number of tokens produced in the output",
    )
    args = parser.parse_args()

    # Load the questions: one JSON object per line, each providing a
    # "question_id" and a "text" field.
    questions_dict = {}
    with open(os.path.expanduser(args.question), encoding="utf-8") as f:
        for line in f:
            # Lines read from a file keep their trailing newline, so strip
            # before testing for emptiness; otherwise a blank line would reach
            # json.loads and raise JSONDecodeError.
            if not line.strip():
                continue
            q = json.loads(line)
            questions_dict[q["question_id"]] = q["text"]

    answers = []

    # Each get_answer call blocks on an API request, so the calls are fanned
    # out over a thread pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(get_answer, qid, question, args.max_tokens)
            for qid, question in questions_dict.items()
        ]

        # Collect results as they finish so the progress bar advances per answer.
        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            answers.append(future.result())

    # Completion order is arbitrary; sort so the output order is deterministic.
    answers.sort(key=lambda x: x["question_id"])

    # Write one JSON object per line (no trailing newline after the last one).
    with open(os.path.expanduser(args.output), "w", encoding="utf-8") as f:
        table = [json.dumps(ans) for ans in answers]
        f.write("\n".join(table))
|
|