File size: 3,452 Bytes
ee7776a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import json
import random
import argparse
import os
from tqdm import tqdm
from utils import dotdict
from stqdm import stqdm

import openai
from model import gpt, gpt_usage, OPENAI_API_KEY
from prompts import auditor_prompt, auditor_format_constrain
from prompts import topk_prompt1, topk_prompt2

# Module-level token tallies, both starting at zero.
# NOTE(review): nothing in the visible part of this file reads or updates
# them; presumably usage is tracked elsewhere -- confirm before relying on them.
completion_tokens = prompt_tokens = 0

def remove_spaces(s):
    """Collapse each run of whitespace in ``s`` into one space.

    Leading and trailing whitespace is dropped, and newlines and tabs count
    as whitespace, so the result is always a single line.
    """
    words = s.split()
    return ' '.join(words)

def prompt_wrap(prompt, format_constraint, code, topk):
    """Assemble the full auditor prompt as one string.

    The pieces are concatenated in this order: ``prompt``, ``code``,
    ``format_constraint``, ``topk_prompt1`` with ``topk`` substituted in,
    and finally ``topk_prompt2``.
    """
    topk_clause = topk_prompt1.format(topk=topk)
    pieces = (prompt, code, format_constraint, topk_clause, topk_prompt2)
    return "".join(pieces)

def auditor_response_parse(auditor_outputs):
    """Collect the findings reported in a batch of raw auditor replies.

    Each reply is expected to contain a JSON object somewhere in its text;
    the span from the first ``{`` to the last ``}`` is parsed, and the
    entries of its ``"output_list"`` array are gathered.

    Args:
        auditor_outputs: Iterable of reply strings.

    Returns:
        A flat list with the dict entries of every reply's ``"output_list"``,
        in reply order. Replies that cannot be parsed, or that carry no
        ``"output_list"`` array, contribute nothing; a message is printed for
        each such reply.
    """
    output_list = []
    for auditor_output in auditor_outputs:
        try:
            start_idx = auditor_output.find("{")
            end_idx = auditor_output.rfind("}")
            data = json.loads(auditor_output[start_idx:end_idx + 1])
        except (AttributeError, TypeError, ValueError, RecursionError):
            # AttributeError/TypeError: the reply is not a string (e.g. None).
            # ValueError covers json.JSONDecodeError. Unlike a bare ``except``
            # this lets KeyboardInterrupt/SystemExit propagate.
            print("parsing json fail.")
            continue

        findings = data.get("output_list") if isinstance(data, dict) else None
        if not isinstance(findings, list):
            # A missing key, or a non-list value: ``+=`` on a string would
            # silently splat it into single characters, so reject it here.
            print("No vulnerability detected")
            continue

        # Callers attach extra keys to each finding, so keep only dicts.
        output_list.extend(item for item in findings if isinstance(item, dict))

    return output_list

def solve(args, code):
    """Query the auditor model about ``code`` and return the parsed findings.

    The prompt is built with ``prompt_wrap`` from the auditor prompt, the
    format constraint, ``code`` and ``args.topk``. The model is called with
    ``args.backend``, ``args.temperature`` and ``n=args.num_auditor``.

    Returns:
        Whatever ``auditor_response_parse`` extracts from the replies. If the
        model call or the parsing raises, the exception is printed and an
        empty list is returned (best effort: one bad file must not abort a
        whole run).
    """
    auditor_input = prompt_wrap(auditor_prompt, auditor_format_constrain, code, args.topk)

    try:
        raw_replies = gpt(
            auditor_input,
            model=args.backend,
            temperature=args.temperature,
            n=args.num_auditor,
        )
        return auditor_response_parse(raw_replies)
    except Exception as err:
        print(err)
        return []

def _read_option(args, name):
    """Return option ``name`` from ``args``, or ``None`` when it is absent."""
    # Mapping-style configs expose ``.get``; argparse.Namespace does not, so
    # fall back to plain attribute access for it.
    getter = getattr(args, 'get', None)
    if callable(getter):
        return getter(name)
    return getattr(args, name, None)


def run(args):
    """Audit every ``.sol`` file in ``args.data_dir`` and log the findings.

    For each Solidity file the source is read, whitespace-collapsed with
    ``remove_spaces`` and passed to ``solve``. Each returned finding gets a
    ``"file_name"`` key, and the findings are written as JSON to
    ``./src/logs/auditor_<backend>_<temperature>_top<topk>_<num_auditor>/<name>.json``.
    Files for which no finding comes back are reported as failed and skipped.

    Args:
        args: Configuration with attributes ``backend``, ``temperature``,
            ``topk``, ``num_auditor`` and ``data_dir``, plus an optional
            ``openai_api_key``. Both mapping-style objects with ``.get`` and
            ``argparse.Namespace`` are accepted.

    Side effects:
        Sets ``openai.api_key`` (falling back to ``OPENAI_API_KEY`` when no
        key is supplied), creates the log directory on demand and writes one
        JSON file per audited contract.
    """
    api_key = _read_option(args, 'openai_api_key')
    openai.api_key = OPENAI_API_KEY if api_key is None else api_key

    # log output directory
    log_dir = f"./src/logs/auditor_{args.backend}_{args.temperature}_top{args.topk}_{args.num_auditor}"
    source_suffix = ".sol"

    for file_name in stqdm(os.listdir(args.data_dir)):
        if not file_name.endswith(source_suffix):
            continue

        with open(os.path.join(args.data_dir, file_name), "r", encoding="utf-8") as f:
            code = f.read()
        # remove space
        code = remove_spaces(code)

        # auditing; only dict findings can be tagged with the file name
        bug_info_list = [info for info in solve(args, code) if isinstance(info, dict)]

        # Sometimes the query fails because the model does not strictly follow the format
        if not bug_info_list:
            print(f"{file_name} failed")
            continue

        for info in bug_info_list:
            info["file_name"] = file_name

        # Swap only the trailing ".sol"; str.replace would also rewrite any
        # ".sol" appearing earlier in the name.
        json_name = file_name[:-len(source_suffix)] + ".json"
        os.makedirs(log_dir, exist_ok=True)

        with open(os.path.join(log_dir, json_name), 'w', encoding="utf-8") as f:
            json.dump(bug_info_list, f, indent=4)

def parse_args():
    """Parse the auditor's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``backend``, ``temperature``,
        ``data_dir``, ``topk`` and ``num_auditor``.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ('--backend', dict(type=str,
                           choices=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo-preview'],
                           default='gpt-4-turbo-preview')),
        ('--temperature', dict(type=float, default=0.7)),
        ('--data_dir', dict(type=str, default="data/CVE_clean")),
        # the topk per each auditor
        ('--topk', dict(type=int, default=5)),
        ('--num_auditor', dict(type=int, default=1)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()

if __name__ == "__main__":
    # Script entry point: parse the CLI options, echo them, then run the audit.
    args = parse_args()
    print(args)
    run(args)

def mainfnc(args=None):
    """Programmatic entry point: run the audit with a caller-supplied config.

    Args:
        args: Configuration object forwarded unchanged to ``run``, which
            reads ``backend``, ``temperature``, ``topk``, ``num_auditor``
            and ``data_dir`` from it, plus an optional ``openai_api_key``.

    Raises:
        TypeError: If ``args`` is omitted.
    """
    # The default used to be the ``dotdict`` class itself rather than an
    # instance, so a call without arguments could never work. Fail early with
    # a clear message instead.
    if args is None:
        raise TypeError(
            "mainfnc() requires a configuration object (e.g. a dotdict) as 'args'"
        )
    run(args)