import logging
from openai import OpenAI
from tqdm import tqdm
from collections import defaultdict
import traceback
import httpx

from backend.utils.data_process import split_to_file_diff, split_to_section
from backend.section_infer_helper.base_helper import BaseHelper


logger = logging.getLogger(__name__)


class OnlineLLMHelper(BaseHelper):
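    """Patch-section vulnerability classification backed by an OpenAI-compatible chat completion API."""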

    # Rough cap on prompt size, enforced as a whitespace-separated word count rather than model tokens.
    MAX_LENGTH = 4096
    # Upper bound on the number of tokens the model may generate per request.
    MAX_NEW_TOKENS = 16
    
    PREDEF_MODEL = ["gpt-3.5-turbo", "deepseek-chat", "qwen-coder-plus", "gpt-4-turbo", "gpt-4o", "gemini-1.5-pro-latest", "claude-3-5-sonnet-20241022"]
    
    # Every model name resolves to the same default configuration; the defaultdict returns it for any key.
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    
    SYSTEM_PROMPT = "You are an expert in code vulnerability and patch fixes."
    
    @staticmethod
    def generate_instruction(language, file_name, patch, section, message = None):
        instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        
        return instruction


    def __init__(self):
        self.model_name = None
        self.url = None
        self.key = None
        self.openai_client = None


    @staticmethod
    def generate_message(filename, patch, section, patch_message = None):
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        user_message = OnlineLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
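        # Truncate to MAX_LENGTH whitespace-separated words as a coarse stand-in for the model's token limit.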
        user_message = user_message.split(" ")
        user_message = user_message[:OnlineLLMHelper.MAX_LENGTH]
        user_message = " ".join(user_message)
        messages = [
            {
                "role": "system",
                "content": OnlineLLMHelper.SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": user_message
            }
        ]
        return messages
    

    def load_model(self, model_name, url, api_key):
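        """Create an OpenAI-compatible chat client; `url` may point at any compatible provider endpoint."""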
        self.model_name = model_name

        self.openai_client = OpenAI(
            base_url = url,
            api_key = api_key,
            timeout=httpx.Timeout(15.0)
        )
    
    
    def infer(self, diff_code, message = None, batch_size=1):
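        """Split a unified diff per supported file, then ask the model whether each section fixes a vulnerability."""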
        if self.model_name is None:
            raise RuntimeError("Model is not loaded")
        
        results = {}
        input_list = []
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(OnlineLLMHelper.MODEL_CONFIGS[self.model_name]["supported_languages"]))
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                input_list.append(BaseHelper.InputData(file_name, section, section, message))

        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
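            # predict: 1 = the section fixes a vulnerability, 0 = it does not, -1 = inference error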
            results[file_name].append({
                "section": section,
                "predict": -1 if output_text_i == "error" else 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i
            })

        return results
        
    
    def do_infer(self, input_list, batch_size = 1):
        # batch_size is accepted for interface compatibility but unused here: requests are sent one at a time.
        input_prompt = []
        for input_data in input_list:
            input_prompt.append(OnlineLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
        
        if len(input_prompt) > 0:
            logger.info("Example input prompt: %s", input_prompt[0])
        output_text = []
        for prompt, input_data in tqdm(zip(input_prompt, input_list), desc="Inferencing", unit = "section", total = len(input_prompt)):
            try:
                response = self.openai_client.chat.completions.create(
                    messages = prompt,
                    model = self.model_name,
                    max_completion_tokens = OnlineLLMHelper.MAX_NEW_TOKENS
                )
                output_text.append(response.choices[0].message.content)
            except KeyboardInterrupt:
                logger.error("KeyboardInterrupt received, stopping inference early")
                break
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Error inferencing: {input_data.filename} - {input_data.section}")
                logger.error(traceback.format_exc())
                # Record a sentinel so output_text stays aligned with input_prompt.
                output_text.append("error")
        
        # Log-probabilities are not requested from the API, so every answer is reported with a fixed confidence of 1.0.
        output_prob = [1.0] * len(output_text)
        return input_prompt, output_text, output_prob


online_llm_helper = OnlineLLMHelper()
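

# Minimal usage sketch (illustrative only: the model name, endpoint URL, API key,
# and diff text below are placeholders, not values shipped with this module):
#
#     online_llm_helper.load_model(
#         "gpt-4o",
#         "https://api.openai.com/v1",
#         "sk-...",
#     )
#     results = online_llm_helper.infer(diff_text, message="patch commit message")
#     # `results` maps each changed file to per-section predictions, e.g.
#     # {"src/parser.c": [{"section": "@@ ... @@", "predict": 1, "conf": 1.0}, ...]}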