import argparse
import json
import os
import random

import openai
from stqdm import stqdm
from tqdm import tqdm

from model import gpt, gpt_usage, OPENAI_API_KEY
from prompts import auditor_prompt, auditor_format_constrain
from prompts import topk_prompt1, topk_prompt2
from utils import dotdict

completion_tokens = 0
prompt_tokens = 0


def remove_spaces(s):
    """Collapse every whitespace run in *s* into one space and trim both ends."""
    return ' '.join(s.split())


def prompt_wrap(prompt, format_constraint, code, topk):
    """Assemble the auditor prompt sent to the model.

    The pieces are concatenated in this order: *prompt*, *code*,
    *format_constraint*, ``topk_prompt1`` formatted with *topk*, and
    ``topk_prompt2``.
    """
    return (
        prompt
        + code
        + format_constraint
        + topk_prompt1.format(topk=topk)
        + topk_prompt2
    )


def auditor_response_parse(auditor_outputs):
    """Extract the ``output_list`` entries from raw auditor responses.

    For each response, the text between the first ``{`` and the last ``}`` is
    parsed as JSON and the list stored under its ``"output_list"`` key is
    appended to the result. Responses that cannot be parsed, or that carry no
    usable ``output_list``, are reported on stdout and skipped.

    Args:
        auditor_outputs: Iterable of response strings.

    Returns:
        A flat list with the ``output_list`` entries of all usable responses.
    """
    output_list = []
    for auditor_output in auditor_outputs:
        try:
            start_idx = auditor_output.find("{")
            end_idx = auditor_output.rfind("}")
            data = json.loads(auditor_output[start_idx: end_idx + 1])
        except (AttributeError, TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; AttributeError/TypeError
            # cover a non-string response. Deliberately narrower than a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate.
            print("parsing json fail.")
            continue

        entries = data.get("output_list") if isinstance(data, dict) else None
        if not isinstance(entries, list):
            # Missing key, or a value (e.g. a string or dict) that `+=` would
            # otherwise splice in element by element.
            print("No vulnerability detected")
            continue
        output_list.extend(entries)

    return output_list


def solve(args, code):
    """Query the model about *code* and return the parsed bug entries.

    Args:
        args: Options object providing ``topk``, ``backend``, ``temperature``
            and ``num_auditor``.
        code: Source text of the contract to audit.

    Returns:
        The list produced by ``auditor_response_parse``; an empty list when
        the query or the parsing raised.
    """
    bug_info_list = []
    auditor_input = prompt_wrap(
        auditor_prompt, auditor_format_constrain, code, args.topk
    )

    try:
        auditor_outputs = gpt(
            auditor_input,
            model=args.backend,
            temperature=args.temperature,
            n=args.num_auditor,
        )
        bug_info_list = auditor_response_parse(auditor_outputs)
    except Exception as e:
        # Best effort: one failed query must not abort the whole batch.
        print(e)

    return bug_info_list


def _resolve_api_key(args):
    """Return the API key carried by *args*, or ``OPENAI_API_KEY`` if absent.

    Works for both dict-like option objects and ``argparse.Namespace``; the
    latter has no ``.get`` method, so plain ``args.get(...)`` would raise.
    """
    if isinstance(args, dict):
        api_key = args.get('openai_api_key')
    else:
        api_key = getattr(args, 'openai_api_key', None)
    return OPENAI_API_KEY if api_key is None else api_key


def run(args):
    """Audit every ``.sol`` file in ``args.data_dir`` and write JSON logs.

    For each Solidity file the whitespace-normalised source is sent through
    ``solve``; every returned entry is tagged with the originating
    ``file_name`` and the entries are dumped to
    ``<log_dir>/<file stem>.json``. Files for which no entry comes back are
    reported as failed and produce no log file.

    Args:
        args: Options object with ``backend``, ``temperature``, ``topk``,
            ``num_auditor``, ``data_dir`` and optionally ``openai_api_key``.
    """
    openai.api_key = _resolve_api_key(args)

    # log output file
    log_dir = (
        f"./src/logs/auditor_{args.backend}_{args.temperature}"
        f"_top{args.topk}_{args.num_auditor}"
    )

    for file_name in stqdm(os.listdir(args.data_dir)):
        if not file_name.endswith(".sol"):
            continue

        with open(f"{args.data_dir}/{file_name}", "r") as f:
            code = f.read()

        # remove space
        code = remove_spaces(code)

        # auditing; entries that are not JSON objects cannot be tagged with a
        # file name below, so they are dropped instead of crashing the run.
        bug_info_list = [
            info for info in solve(args, code) if isinstance(info, dict)
        ]

        if not bug_info_list:
            # Sometimes the query fails because the model does not strictly
            # follow the format
            print(f"{file_name} failed")
            continue

        all_bug_info_list = []
        for info in bug_info_list:
            info.update({"file_name": file_name})
            all_bug_info_list.append(info)

        # Swap only the trailing extension; str.replace would also rewrite
        # any earlier ".sol" inside the name.
        stem = os.path.splitext(file_name)[0]
        file = f"{log_dir}/{stem}.json"
        os.makedirs(os.path.dirname(file), exist_ok=True)

        with open(file, 'w') as f:
            json.dump(all_bug_info_list, f, indent=4)


def parse_args():
    """Parse the command-line options of the auditor script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--backend',
        type=str,
        choices=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo-preview'],
        default='gpt-4-turbo-preview',
    )
    parser.add_argument('--temperature', type=float, default=0.7)
    parser.add_argument('--data_dir', type=str, default="data/CVE_clean")
    # the topk per each auditor
    parser.add_argument('--topk', type=int, default=5)
    parser.add_argument('--num_auditor', type=int, default=1)
    return parser.parse_args()


def mainfnc(args=dotdict):
    """Programmatic entry point: run the auditor with ready-made options.

    NOTE(review): the default is the ``dotdict`` class itself, not an
    instance, so callers are expected to always pass a populated options
    object.
    """
    run(args)


if __name__ == '__main__':
    args = parse_args()
    print(args)
    run(args)