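"""Helper that queries an online, OpenAI-compatible chat model to decide,
section by section, whether a patch fixes a software vulnerability."""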
import logging
from openai import OpenAI
from tqdm import tqdm
from collections import defaultdict
import traceback
import httpx
from backend.utils.data_process import split_to_file_diff, split_to_section
from backend.section_infer_helper.base_helper import BaseHelper
logger = logging.getLogger(__name__)


class OnlineLLMHelper(BaseHelper):
    MAX_LENGTH = 4096
    MAX_NEW_TOKENS = 16
    PREDEF_MODEL = [
        "gpt-3.5-turbo",
        "deepseek-chat",
        "qwen-coder-plus",
        "gpt-4-turbo",
        "gpt-4o",
        "gemini-1.5-pro-latest",
        "claude-3-5-sonnet-20241022",
    ]
    MODEL_CONFIGS = defaultdict(lambda: {
        "supported_languages": ["C", "C++", "Java", "Python"],
    })
    SYSTEM_PROMPT = "You are an expert in code vulnerability and patch fixes."

    @staticmethod
    def generate_instruction(language, file_name, patch, section, message=None):
        # Build the user prompt: task description, the full patch, the patch
        # section under scrutiny, and (optionally) the patch message.
        instruction = "[TASK]\nHere is a patch in {} language and a section of this patch for a source code file with path {}. Determine if the patch section fixes any software vulnerabilities. Output 'yes' or 'no' and do not output any other text.\n".format(language, file_name)
        instruction += "[Patch]\n{}\n".format(patch)
        instruction += "[A section of this patch]\n{}\n".format(section)
        if message is not None and message != "":
            instruction += "[Message of the Patch]\n{}\n".format(message)
        return instruction

    def __init__(self):
        self.model_name = None
        self.url = None
        self.key = None
        self.openai_client = None

    @staticmethod
    def generate_message(filename, patch, section, patch_message=None):
        ext = filename.split(".")[-1]
        language = BaseHelper._get_lang_by_ext(ext)
        user_message = OnlineLLMHelper.generate_instruction(language, filename, patch, section, patch_message)
        # Truncate the prompt to at most MAX_LENGTH whitespace-separated tokens.
        user_message = user_message.split(" ")
        user_message = user_message[:OnlineLLMHelper.MAX_LENGTH]
        user_message = " ".join(user_message)
        messages = [
            {
                "role": "system",
                "content": OnlineLLMHelper.SYSTEM_PROMPT
            },
            {
                "role": "user",
                "content": user_message
            }
        ]
        return messages

    def load_model(self, model_name, url, api_key):
        self.model_name = model_name
        self.openai_client = OpenAI(
            base_url=url,
            api_key=api_key,
            timeout=httpx.Timeout(15.0)
        )

    def infer(self, diff_code, message=None, batch_size=1):
        if self.model_name is None:
            raise RuntimeError("Model is not loaded")
        results = {}
        input_list = []
        # Split the diff into per-file diffs for the supported language
        # extensions, then split each file diff into sections that are
        # classified individually.
        file_diff_list = split_to_file_diff(diff_code, BaseHelper._get_lang_ext(OnlineLLMHelper.MODEL_CONFIGS[self.model_name]["supported_languages"]))
        for file_a, _, file_diff in file_diff_list:
            sections = split_to_section(file_diff)
            file_name = file_a.removeprefix("a/")
            results[file_name] = []
            for section in sections:
                input_list.append(BaseHelper.InputData(file_name, section, section, message))
        input_prompt, output_text, output_prob = self.do_infer(input_list, batch_size)
        assert len(input_list) == len(input_prompt) == len(output_text) == len(output_prob)
        for i in range(len(input_list)):
            file_name = input_list[i].filename
            section = input_list[i].section
            output_text_i = output_text[i].lower()
            output_prob_i = output_prob[i]
            results[file_name].append({
                "section": section,
                # -1: request failed, 1: vulnerability fix, 0: no vulnerability fix
                "predict": -1 if output_text_i == "error" else 1 if "yes" in output_text_i else 0,
                "conf": output_prob_i
            })
        return results

    def do_infer(self, input_list, batch_size=1):
        input_prompt = []
        for input_data in input_list:
            input_prompt.append(OnlineLLMHelper.generate_message(input_data.filename, input_data.patch, input_data.section, input_data.patch_msg))
        if len(input_prompt) > 0:
            logger.info("Example input prompt: %s", input_prompt[0])
        output_text = []
        for prompt, input_data in tqdm(zip(input_prompt, input_list), desc="Running inference", unit="section", total=len(input_prompt)):
            try:
                response = self.openai_client.chat.completions.create(
                    messages=prompt,
                    model=self.model_name,
                    max_completion_tokens=OnlineLLMHelper.MAX_NEW_TOKENS
                )
                output_text.append(response.choices[0].message.content)
            except KeyboardInterrupt:
                logger.error("Interrupted by user, stopping inference early")
                break
            except Exception as e:
                logger.error(f"Error: {e}")
                logger.error(f"Inference failed for: {input_data.filename} - {input_data.section}")
                logger.error(traceback.format_exc())
                output_text.append("error")
                continue
        # Token probabilities are not requested from the chat completions
        # endpoint, so every answer is reported with a confidence of 1.0.
        output_prob = [1.0] * len(output_text)
        return input_prompt, output_text, output_prob


online_llm_helper = OnlineLLMHelper()
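

# Usage sketch (hypothetical values, not part of the original module): point the
# helper at an OpenAI-compatible endpoint and classify the sections of a unified
# diff. The base URL, API key, and example diff below are illustrative placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    example_diff = (
        "diff --git a/src/parser.c b/src/parser.c\n"
        "--- a/src/parser.c\n"
        "+++ b/src/parser.c\n"
        "@@ -10,2 +10,2 @@\n"
        "     char buf[16];\n"
        "-    strcpy(buf, user_input);\n"
        "+    strncpy(buf, user_input, sizeof(buf) - 1);\n"
    )
    online_llm_helper.load_model(
        model_name="gpt-4o",               # any entry from PREDEF_MODEL
        url="https://api.openai.com/v1",   # assumed OpenAI-compatible base URL
        api_key="YOUR_API_KEY",            # placeholder
    )
    # infer() returns {file_name: [{"section": ..., "predict": -1/0/1, "conf": ...}, ...]}
    results = online_llm_helper.infer(example_diff, message="Fix buffer overflow in parser")
    print(results)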