Spaces:
Runtime error
Runtime error
File size: 3,452 Bytes
ee7776a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import json
import random
import argparse
import os
from tqdm import tqdm
from utils import dotdict
from stqdm import stqdm
import openai
from model import gpt, gpt_usage, OPENAI_API_KEY
from prompts import auditor_prompt, auditor_format_constrain
from prompts import topk_prompt1, topk_prompt2
# Module-level token counters, both starting at zero.
# NOTE(review): nothing visible in this file reads or updates them — confirm they are still needed.
completion_tokens, prompt_tokens = 0, 0
def remove_spaces(s):
    """Collapse every run of whitespace in *s* into a single space.

    Leading and trailing whitespace is removed too, so newlines and tabs
    vanish entirely: ``"  a\\n\\tb "`` becomes ``"a b"``.
    """
    words = s.split()
    return " ".join(words)
def prompt_wrap(prompt, format_constraint, code, topk):
    """Assemble the full auditor prompt for one contract.

    Pieces are concatenated in this order: the base prompt, the contract
    source, the output-format constraint, then ``topk_prompt1`` with its
    ``{topk}`` placeholder filled in, and finally ``topk_prompt2``.
    """
    head = prompt + code + format_constraint
    return head + topk_prompt1.format(topk=topk) + topk_prompt2
def auditor_response_parse(auditor_outputs):
    """Collect the findings from a batch of raw auditor completions.

    Each completion is expected to embed a JSON object (possibly surrounded
    by extra text) whose ``"output_list"`` key holds a list of finding dicts.
    The object is taken as the span from the first ``{`` to the last ``}``.

    Malformed completions are skipped, not raised, so one bad completion does
    not discard the others:
      * non-string outputs or text without parseable JSON print
        ``parsing json fail.``;
      * objects lacking a list under ``"output_list"`` print
        ``No vulnerability detected``.

    Args:
        auditor_outputs: Iterable of completion texts.

    Returns:
        A flat list of finding dicts from all usable completions. Non-dict
        entries are dropped because ``run`` tags each finding via
        ``info.update(...)``.
    """
    output_list = []
    for auditor_output in auditor_outputs:
        # A completion may come back empty/None; treat it as unparseable.
        if not isinstance(auditor_output, str):
            print("parsing json fail.")
            continue
        start_idx = auditor_output.find("{")
        end_idx = auditor_output.rfind("}")
        if start_idx == -1 or end_idx < start_idx:
            print("parsing json fail.")
            continue
        try:
            data = json.loads(auditor_output[start_idx:end_idx + 1])
        except json.JSONDecodeError:
            print("parsing json fail.")
            continue
        # The slice starts with "{", so a successful parse is always a dict.
        findings = data.get("output_list")
        # Only a real list is accepted: `+=` with a str or dict would splice in
        # characters or keys as bogus "findings".
        if not isinstance(findings, list):
            print("No vulnerability detected")
            continue
        output_list.extend(item for item in findings if isinstance(item, dict))
    return output_list
def solve(args, code):
    """Run the auditor model on one contract and return the parsed findings.

    The prompt is built with ``prompt_wrap``; errors there propagate. Then
    ``args.num_auditor`` completions are requested from ``args.backend`` at
    ``args.temperature`` and parsed with ``auditor_response_parse``. Any
    exception from the model call or the parsing is printed, and an empty
    list is returned, which is indistinguishable from "no findings".
    """
    prompt_text = prompt_wrap(auditor_prompt, auditor_format_constrain, code, args.topk)
    try:
        completions = gpt(prompt_text, model=args.backend, temperature=args.temperature, n=args.num_auditor)
        return auditor_response_parse(completions)
    except Exception as err:
        # Best-effort: one failed request should not abort the whole batch.
        print(err)
        return []
def _resolve_api_key(args):
    """Return ``args.openai_api_key`` for either dict-style or Namespace args.

    Mapping-style configs are read with ``.get``. ``argparse.Namespace``
    has no ``.get``, so it is read with ``getattr``. Missing yields None.
    """
    if isinstance(args, dict):
        return args.get('openai_api_key')
    return getattr(args, 'openai_api_key', None)


def _log_path(log_dir, file_name):
    """Map ``<name>.sol`` to ``<log_dir>/<name>.json``, changing only the suffix."""
    stem, _ = os.path.splitext(file_name)
    return os.path.join(log_dir, stem + '.json')


def run(args):
    """Audit every ``.sol`` file in ``args.data_dir`` and log findings as JSON.

    ``args`` must expose ``backend``, ``temperature``, ``topk``,
    ``num_auditor`` and ``data_dir`` as attributes. ``openai_api_key`` is
    optional; when it is absent or empty, ``model.OPENAI_API_KEY`` is used.
    Both ``argparse.Namespace`` (from ``parse_args``) and mapping-style
    configs with ``.get`` are accepted.

    Each contract that yields at least one finding gets a JSON file under
    ``./src/logs/auditor_<backend>_<temperature>_top<topk>_<num_auditor>/``
    named after the contract. Each finding is tagged with ``file_name``.
    Contracts with no findings are reported on stdout and get no file.
    """
    api_key = _resolve_api_key(args)
    openai.api_key = api_key if api_key else OPENAI_API_KEY

    # log output file
    log_dir = f"./src/logs/auditor_{args.backend}_{args.temperature}_top{args.topk}_{args.num_auditor}"
    # Sorted so contracts are processed (and reported) in a reproducible order.
    for file_name in stqdm(sorted(os.listdir(args.data_dir))):
        if not file_name.endswith(".sol"):
            continue
        with open(os.path.join(args.data_dir, file_name), "r", encoding="utf-8") as f:
            code = f.read()
        # Flatten whitespace before prompting.
        # NOTE(review): this also joins lines, so in the flattened text a `//`
        # comment runs straight into the code that followed it — confirm the
        # model copes with that.
        code = remove_spaces(code)
        bug_info_list = solve(args, code)
        if not bug_info_list:
            # Sometimes the query fails because the model does not strictly follow the format.
            print("{index} failed".format(index=file_name))
            continue
        for info in bug_info_list:
            info["file_name"] = file_name
        out_path = _log_path(log_dir, file_name)
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(bug_info_list, f, indent=4)
def parse_args(argv=None):
    """Parse the auditor's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` as before. Passing a list (e.g. ``[]`` for the
            defaults) lets callers build a config programmatically.

    Returns:
        ``argparse.Namespace`` with ``backend``, ``temperature``,
        ``data_dir``, ``topk`` and ``num_auditor``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--backend', type=str, choices=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo-preview'], default='gpt-4-turbo-preview')
    parser.add_argument('--temperature', type=float, default=0.7)
    parser.add_argument('--data_dir', type=str, default="data/CVE_clean")
    parser.add_argument('--topk', type=int, default=5)  # the topk per each auditor
    parser.add_argument('--num_auditor', type=int, default=1)
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read CLI flags, echo them, then run the audit.
    args = parse_args()
    print(args)
    run(args)
def mainfnc(args=None):
    """Programmatic entry point that runs the audit with the given config.

    Args:
        args: Configuration passed straight to ``run``. When omitted, the
            options come from ``parse_args()`` and are wrapped in ``dotdict``.
    """
    if args is None:
        # The old default was the ``dotdict`` class itself, not an instance,
        # so a bare ``mainfnc()`` handed a class to ``run``.
        # NOTE(review): assumes ``utils.dotdict`` can be built from a mapping
        # (e.g. it subclasses ``dict``) — confirm.
        args = dotdict(vars(parse_args()))
    run(args)