|
""" |
|
Original work: |
|
https://github.com/sangHa0411/CloneDetection/blob/main/utils/preprocessor.py |
|
|
|
Copyright (c) 2022 Sangha Park(sangha110495), Young Jin Ahn(snoop2head) |
|
|
|
All credits to the original authors. |
|
""" |
|
import re |
|
import torch |
|
from transformers import Pipeline |
|
|
|
|
|
class FunctionPreprocessor:
    """Removes function definitions that are declared but never referenced.

    A function counts as "unused" when its name occurs at most once in the
    code — i.e. only at its own ``def`` site.
    """

    def get_function(self, code):
        """Return the names of all newline-prefixed ``def`` functions in *code*.

        ``preprocess`` prepends a newline so that a leading ``def`` is found.
        """
        # Raw string: "\(" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on modern Python).
        fn_list = re.findall(r"\ndef [a-zA-Z0-9_]+\(", code)
        # Each match looks like "\ndef name("; drop "\ndef " (4 chars after
        # the newline) and the trailing "(".
        return [fn[4:-1].strip() for fn in fn_list]

    def determine_function(self, code, function_name):
        """Return True when *function_name* is referenced beyond its definition.

        One boundary-delimited occurrence is the ``def`` line itself; more
        than one means the function is actually used somewhere.
        """
        num = len(re.findall("[^a-zA-Z]" + function_name + "[^a-zA-Z]", code))
        return num > 1

    def delete_function(self, code, name):
        """Remove the definition of function *name* from *code*.

        The function's extent runs from its ``def`` line up to (but not
        including) the next line that starts with an alphabetic character,
        i.e. the next top-level statement (bodies are indented).
        """
        match = re.search("def " + name, code)
        if match is None:
            # Defensive: nothing to delete if the name is not found.
            return code
        start_id = match.span()[0]
        ptr = start_id

        # Advance to the newline immediately preceding the next top-level line.
        while ptr < len(code) - 1:
            if code[ptr] == "\n" and re.search("[a-zA-Z]", code[ptr + 1]) is not None:
                break
            ptr += 1

        # If the scan stopped before EOF, splice the function out; otherwise
        # the function runs to the end of the string and is kept (original
        # behavior preserved).
        if ptr != len(code) - 1:
            code = code[:start_id] + code[ptr:]

        return code

    def preprocess(self, code):
        """Return *code* with all unused function definitions removed."""
        code = "\n" + code  # guarantee every "def" is newline-prefixed
        fn_list = self.get_function(code)
        if not fn_list:
            return code

        for fn in fn_list:
            if not self.determine_function(code, fn):
                code = self.delete_function(code, fn)

        return code
|
|
|
|
|
class AnnotationPreprocessor:
    """Strips docstring blocks, ``#`` comments, import lines and excess
    whitespace from a code string, leaving a single-spaced token stream."""

    def search(self, sen_list, string):
        """Return the index of the first line containing *string*, or -1."""
        for i, sen in enumerate(sen_list):
            if string in sen:
                return i
        return -1

    def delete_annotation_block(self, code, string):
        """Remove the first block delimited by *string* (e.g. a triple quote).

        Deletes from the line with the opening delimiter through the line
        with the closing delimiter; when no closing delimiter exists, only
        the opening line is removed.
        """
        sens = code.split("\n")

        start_id = self.search(sens, string)
        end_id = self.search(sens[start_id + 1 :], string)
        if end_id != -1:
            end_id += start_id + 1  # re-base onto the full line list
            sens = sens[:start_id] + sens[end_id + 1 :]
        else:
            sens = sens[:start_id] + sens[start_id + 1 :]

        return "\n".join(sens)

    def delete_block(self, code, string):
        """Repeatedly remove *string*-delimited blocks until none remain."""
        while string in code:
            code = self.delete_annotation_block(code, string)
        return code

    def delete_annotation(self, code):
        """Drop everything from ``#`` to end-of-line on every line.

        NOTE(review): this also truncates lines where '#' appears inside a
        string literal — confirm that is acceptable for this use case.
        """
        sens_processed = []
        for sen in code.split("\n"):
            if "#" in sen:
                sen = sen[: sen.index("#")]
            sens_processed.append(sen)
        return "\n".join(sens_processed)

    def delete_import(self, code):
        """Drop every line containing the substring ``import``.

        NOTE(review): plain substring match, so a line mentioning e.g.
        "important" is also removed — verify this is intended.
        """
        return "\n".join(
            sen for sen in code.split("\n") if "import" not in sen
        )

    def preprocess(self, code):
        """Apply all cleanup passes, then collapse whitespace runs to one space."""
        code = self.delete_block(code, '"""')
        code = self.delete_block(code, "'''")
        code = self.delete_annotation(code)
        code = self.delete_import(code)
        # Raw string: "\s" in a plain literal is a deprecated escape sequence.
        code = re.sub(r"\s+", " ", code).strip()
        return code
|
|
|
|
|
def preprocessor(code, instance):
    """Run *instance*'s ``preprocess`` over *code*.

    Falls back to the original *code* when cleaning leaves nothing but
    whitespace, so downstream tokenization never receives an empty string.
    """
    cleaned = instance.preprocess(code)
    if cleaned.strip():
        return cleaned
    return code
|
|
|
|
|
def token_to_inputs(feature):
    """Convert a tokenizer feature mapping into model inputs.

    Each value is turned into a tensor with a leading batch dimension of 1.
    """
    return {name: torch.tensor(values).unsqueeze(0) for name, values in feature.items()}
|
|
|
|
|
class CloneDetectionPipeline(Pipeline):
    """Pipeline scoring whether two code snippets are clones of each other.

    Input: a pair ``(code1, code2)`` of source strings.
    Output: ``{False: p_not_clone, True: p_clone}``.

    Logits are symmetrized by averaging the model outputs for both
    orderings (code1, code2) and (code2, code1).
    """

    # Shared stateless preprocessors (class-level singletons).
    fn_preprocessor = FunctionPreprocessor()
    an_preprocessor = AnnotationPreprocessor()

    def _sanitize_parameters(self, **kwargs):
        # No configurable parameters for preprocess/forward/postprocess.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Clean both snippets and tokenize them in both orders."""
        code1 = inputs[0]
        code2 = inputs[1]
        # Short-circuit on empty input: two empty snippets are trivially
        # clones; an empty vs. non-empty pair is trivially not.
        if code1.strip() == "" or code2.strip() == "":
            true_prob = float(code1.strip() == code2.strip())  # fix: was misspelled "ture_prob"
            return {"skip": True, "output": {False: 1 - true_prob, True: true_prob}}

        code1 = preprocessor(
            preprocessor(code1, self.fn_preprocessor), self.an_preprocessor
        )
        code2 = preprocessor(
            preprocessor(code2, self.fn_preprocessor), self.an_preprocessor
        )

        feature1 = self.tokenizer(
            code1, code2, max_length=512, return_token_type_ids=False, truncation=True
        )
        feature2 = self.tokenizer(
            code2, code1, max_length=512, return_token_type_ids=False, truncation=True
        )

        return {
            "inputs1": token_to_inputs(feature1),
            "inputs2": token_to_inputs(feature2),
        }

    def _forward(self, model_inputs):
        """Run the model on both orderings and average the logits."""
        if model_inputs.get("skip", False):
            return model_inputs

        inputs1 = model_inputs["inputs1"]
        inputs2 = model_inputs["inputs2"]

        logits1 = self.model(**inputs1).logits[0]
        logits2 = self.model(**inputs2).logits[0]
        logits = (logits1 + logits2) / 2

        return {"logits": logits}

    def postprocess(self, model_outputs):
        """Convert averaged logits to a {False: p, True: p} probability dict."""
        if model_outputs.get("skip", False):
            return model_outputs["output"]

        # Index 0 = "not clone", index 1 = "clone" (binary classifier head).
        probs = model_outputs["logits"].softmax(-1).tolist()

        return {False: probs[0], True: probs[1]}
|
|