import argparse
import json
import logging
import os
import time

import openai
import ray
import shortuuid

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Retry each API call up to MAX_API_RETRY times; wait REQ_TIME_GAP
# seconds between consecutive requests to stay under the rate limit.
MAX_API_RETRY = 5
REQ_TIME_GAP = 10


@ray.remote(num_cpus=4)
def get_eval(sys_prompt: str, user_prompt: str, max_tokens: int) -> str:
    # Ray executes this function in a separate worker process, so logging
    # has to be configured again here.
    logging.basicConfig(level=logging.INFO)
    for _ in range(MAX_API_RETRY):
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": sys_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            content = response["choices"][0]["message"]["content"]
            logger.info(content)
            return content
        except Exception as e:
            # Back off briefly and retry on any API error (rate limits,
            # timeouts, transient server errors).
            logger.error(e)
            time.sleep(5)
    logger.error(f"Failed after {MAX_API_RETRY} retries.")
    return "error"

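# A minimal usage sketch for the remote evaluator (illustrative only;
# it mirrors the driver loop at the bottom of this script):
#
#   handle = get_eval.remote(sys_prompt, user_prompt, 1024)
#   content = ray.get(handle)  # blocks until the worker returns
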
def parse_score(review):
    # The review is expected to begin with a line holding the two scores,
    # e.g. "8 7" or "8, 7", followed by the free-form explanation.
    try:
        score_pair = review.split("\n")[0]
        score_pair = score_pair.replace(",", " ")
        # split() without an argument collapses runs of whitespace, so
        # "8, 7" (which becomes "8  7" above) still parses correctly.
        sp = score_pair.split()
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            raise ValueError("Invalid score pair.")
    except Exception as e:
        logger.error(f"{e}\nContent: {review}\nYou must manually fix the score pair.")
        return [-1, -1]

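# Illustrative behavior of parse_score (the review texts are hypothetical):
#
#   parse_score("8 7\nAssistant 1 was more detailed...")  -> [8.0, 7.0]
#   parse_score("8, 7\nBoth answers were accurate...")    -> [8.0, 7.0]
#   parse_score("No scores given.")                       -> [-1, -1]
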
def gen_prompt(reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2):
    # Pick the reviewer that matches the question category; fall back to
    # the first reviewer if no category matches.
    reviewer_idx = 0
    for idx, reviewer in enumerate(reviewer_jsons):
        if reviewer["category"] == cat:
            reviewer_idx = idx
            break
    # prompt_id is 1-based, while prompt_jsons is a 0-based list.
    prompt_id = reviewer_jsons[reviewer_idx]["prompt_id"]
    prompt_json = prompt_jsons[prompt_id - 1]
    assert prompt_json["prompt_id"] == prompt_id

    sys_prompt = prompt_json["system_prompt"]
    prompt_template = prompt_json["prompt_template"]
    defaults = prompt_json["defaults"]
    prompt = prompt_template.format(
        question=ques, answer_1=ans1, answer_2=ans2, **defaults
    )

    # The reviewer id is reported 1-based.
    return sys_prompt, prompt, reviewer_idx + 1

def get_json_list(file_path):
    # Read a JSON Lines file: one JSON object per line.
    file_path = os.path.expanduser(file_path)
    with open(file_path, "r") as f:
        return [json.loads(line) for line in f]

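# All inputs are JSON Lines files. Judging only from the keys this script
# reads, the records look roughly like the following (the values are
# illustrative, not taken from any real file):
#
#   question file: {"question_id": 1, "text": "...", "category": "generic"}
#   answer file:   {"question_id": 1, "answer_id": "...", "text": "..."}
#   reviewer file: {"category": "generic", "prompt_id": 1}
#   prompt file:   {"prompt_id": 1, "system_prompt": "...",
#                   "prompt_template": "...", "defaults": {}}
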
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GPT-4-based QA evaluation.")
    parser.add_argument("-q", "--question-file", required=True)
    parser.add_argument("-a", "--answer-file-list", nargs="+", default=[])
    parser.add_argument("-p", "--prompt-file", required=True)
    parser.add_argument("-r", "--reviewer-file", required=True)
    parser.add_argument("-o", "--output-review-file", required=True)
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1024,
        help="maximum number of tokens produced in the output",
    )
    args = parser.parse_args()
    # The script compares exactly two answer files against each other.
    assert len(args.answer_file_list) == 2

    ray.init()

    question_jsons = get_json_list(args.question_file)
    answer1_jsons = get_json_list(args.answer_file_list[0])
    answer2_jsons = get_json_list(args.answer_file_list[1])
    reviewer_jsons = get_json_list(args.reviewer_file)
    prompt_jsons = get_json_list(args.prompt_file)

    # Every question must have an answer in both answer files.
    assert len(question_jsons) == len(answer1_jsons) == len(answer2_jsons)

    handles = []
    review_jsons = []
    total_len = len(question_jsons)
    question_idx_list = list(range(total_len))

    for i in question_idx_list:
        assert (
            answer1_jsons[i]["question_id"]
            == question_jsons[i]["question_id"]
            == answer2_jsons[i]["question_id"]
        )

        ques = question_jsons[i]["text"]
        cat = question_jsons[i]["category"]
        ans1 = answer1_jsons[i]["text"]
        ans2 = answer2_jsons[i]["text"]
        sys_prompt, prompt, reviewer_id = gen_prompt(
            reviewer_jsons, prompt_jsons, cat, ques, ans1, ans2
        )
        review_id = shortuuid.uuid()
        review_jsons.append(
            {
                "review_id": review_id,
                "question_id": question_jsons[i]["question_id"],
                "answer1_id": answer1_jsons[i]["answer_id"],
                "answer2_id": answer2_jsons[i]["answer_id"],
                "reviewer_id": reviewer_id,
                "metadata": {},
            }
        )

        handles.append(get_eval.remote(sys_prompt, prompt, args.max_tokens))
        logger.info(
            f"Waiting for {REQ_TIME_GAP} seconds before sending the next request."
        )
        time.sleep(REQ_TIME_GAP)

    # ray.get blocks until every pending evaluation has returned, and the
    # results come back in the same order the requests were submitted.
    reviews = ray.get(handles)
    with open(args.output_review_file, "w") as output_review_file:
        for idx, review in enumerate(reviews):
            scores = parse_score(review)
            review_jsons[idx]["text"] = review
            review_jsons[idx]["score"] = scores
            output_review_file.write(json.dumps(review_jsons[idx]) + "\n")
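
# Example invocation (the script and file names are hypothetical):
#
#   python eval_gpt_review.py \
#       -q questions.jsonl \
#       -a answers_model_a.jsonl answers_model_b.jsonl \
#       -p prompts.jsonl \
#       -r reviewers.jsonl \
#       -o reviews.jsonl \
#       --max-tokens 1024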