Habiba A. Elbehairy
Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
a5cd505
import torch | |
import torch.nn as nn | |
from transformers import AutoModel | |
import re | |
class CodeSimilarityClassifier(nn.Module):
    """Pretrained transformer encoder followed by an MLP classification head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``. Its
    pooled (or first-token) representation is passed through dropout and an
    MLP with two hidden layers, producing ``num_labels`` unnormalized logits.
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3,
                 head_dim=512, dropout_prob=0.1):
        """Build the encoder and classification head.

        Args:
            model_name: Model id or local path given to
                ``AutoModel.from_pretrained``.
            num_labels: Number of output logits.
            head_dim: Width of the second hidden layer of the head. The
                default of 512 matches the previously hard-coded value.
            dropout_prob: Dropout probability used on the pooled vector and
                inside the head. The default of 0.1 matches the previous
                hard-coded value.
        """
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_prob)
        hidden_size = self.encoder.config.hidden_size
        # Head: hidden_size -> hidden_size -> head_dim -> num_labels. Each
        # hidden layer is followed by LayerNorm, GELU and dropout. The layer
        # order is kept fixed so existing state_dict keys (classifier.0,
        # .1, .4, .5, .8) keep loading.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size, head_dim),
            nn.LayerNorm(head_dim),
            nn.GELU(),
            nn.Dropout(dropout_prob),
            nn.Linear(head_dim, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Logits tensor whose last dimension is ``num_labels``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Some encoders expose no pooling layer. Their output then has
        # pooler_output=None or no such field. In that case use the
        # first-token hidden state instead. This assumes last_hidden_state
        # is (batch, seq_len, hidden) -- TODO confirm for any new encoder.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]
        # Regularize the pooled vector before the head. Dropout is active
        # only in train() mode and is the identity in eval() mode.
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
def extract_features(source_code, test_code_1, test_code_2):
    """Summarize how two test snippets relate, as one feature string.

    Fields, in output order and joined by `` | ``:

    * ``METADATA``: ``SAME_FIXTURE`` if the two extracted fixture names are
      equal, else ``DIFFERENT_FIXTURE``.
    * ``FIXTURE1`` / ``FIXTURE2``: first argument of the first ``TEST(`` or
      ``TEST_F(`` macro in each snippet, or ``""`` if none is found.
    * ``NAME1`` / ``NAME2``: second argument of the first such macro that has
      one, or ``""``.
    * ``COMMON_ASSERTIONS``: number of distinct ``EXPECT_*`` / ``ASSERT_*``
      names found in both snippets.
    * ``COMMON_CALLS``: number of distinct identifiers directly followed by
      ``(`` (optionally after whitespace) found in both snippets.
    * ``ASSERTION_RATIO``: ``COMMON_ASSERTIONS`` divided by the total number
      of assertion occurrences across both snippets. It is the int ``0``
      when either snippet has no assertions.

    NOTE(review): known quirks, kept because the string is model input text
    and changing it would diverge from trained checkpoints:

    * Two snippets with no detectable fixture both yield ``""`` and are
      reported as ``SAME_FIXTURE``.
    * The ratio's numerator counts distinct names but its denominator counts
      every occurrence. Identical snippets therefore score at most 0.5.
    * ``COMMON_CALLS`` also counts macros (``TEST_F``, ``EXPECT_EQ``) and
      keywords such as ``if``.
    * ``TEST_P(`` is not recognized.

    Args:
        source_code: Accepted for interface compatibility; not read.
        test_code_1: Text of the first test.
        test_code_2: Text of the second test.

    Returns:
        The feature string described above.
    """
    snippets = (test_code_1, test_code_2)

    def _first_capture(pattern, text):
        # Group 1 of the first match, or "" when the pattern is absent.
        found = re.search(pattern, text)
        return found.group(1) if found else ""

    fixture1, fixture2 = [
        _first_capture(r'TEST(?:_F)?\s*\(\s*(\w+)', code) for code in snippets
    ]
    name1, name2 = [
        _first_capture(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', code)
        for code in snippets
    ]

    # findall yields (prefix, suffix) tuples; rejoin them into full macro names.
    asserts1, asserts2 = (
        [head + tail for head, tail in re.findall(r'(EXPECT_|ASSERT_)(\w+)', code)]
        for code in snippets
    )
    calls1, calls2 = (re.findall(r'(\w+)\s*\(', code) for code in snippets)

    shared_asserts = set(asserts1) & set(asserts2)
    shared_calls = set(calls1) & set(calls2)

    if fixture1 != fixture2:
        fixture_flag = "DIFFERENT_FIXTURE"
    else:
        fixture_flag = "SAME_FIXTURE"

    if asserts1 and asserts2:
        assertion_ratio = len(shared_asserts) / (len(asserts1) + len(asserts2))
    else:
        assertion_ratio = 0

    return (
        f"METADATA: {fixture_flag} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {len(shared_asserts)} | "
        f"COMMON_CALLS: {len(shared_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio}"
    )