CodeBert_Redundant_Detection_Task / model_definition.py
Habiba A. Elbehairy
Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
a5cd505
raw
history blame
2.8 kB
import torch
import torch.nn as nn
from transformers import AutoModel
import re
class CodeSimilarityClassifier(nn.Module):
    """Pretrained transformer encoder followed by an MLP classification head.

    Args:
        model_name: Model id or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits produced by the head.
        dropout: Dropout probability used on the pooled vector and inside the
            head. Defaults to ``0.1``, the value previously hard-coded.
        head_hidden_size: Width of the head's second hidden layer. Defaults
            to ``512``, the value previously hard-coded.

    The head remains a single ``nn.Sequential`` with the same layer order, so
    ``state_dict`` keys (``classifier.0``, ``classifier.1``, ``classifier.4``,
    ``classifier.5``, ``classifier.8``) match checkpoints saved from the
    original definition.
    """

    def __init__(
        self,
        model_name="microsoft/codebert-base",
        num_labels=3,
        dropout=0.1,
        head_hidden_size=512,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Dropout on the pooled representation before it enters the head.
        self.dropout = nn.Dropout(dropout)

        hidden_size = self.encoder.config.hidden_size
        # Two hidden layers (Linear -> LayerNorm -> GELU -> Dropout), then the
        # output projection. Keep this order: Sequential indices are the
        # checkpoint keys.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, head_hidden_size),
            nn.LayerNorm(head_hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_size, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token-id tensor, passed straight to the encoder.
            attention_mask: Attention-mask tensor, passed straight to the encoder.

        Returns:
            Logits from the head; the last dimension has size ``num_labels``
            (presumably shaped ``(batch, num_labels)`` -- depends on the encoder).
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Encoders built without a pooling layer either lack ``pooler_output``
        # or set it to None; fall back to the first token's final hidden state
        # (assumes last_hidden_state is (batch, seq, hidden), per HF outputs).
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        # self.dropout was previously constructed but never used. nn.Dropout is
        # the identity in eval mode, so inference outputs are unaffected.
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
def extract_features(source_code, test_code_1, test_code_2):
    """Summarize similarity cues between two test snippets as one text line.

    The result is meant to help the model identify similarities.
    ``source_code`` is part of the signature but is not read here.

    Args:
        source_code: Unused by this function.
        test_code_1: Text of the first test.
        test_code_2: Text of the second test.

    Returns:
        A string of ``KEY: value`` fields joined by ``" | "``:
        METADATA (SAME_FIXTURE / DIFFERENT_FIXTURE), FIXTURE1, FIXTURE2, NAME1,
        NAME2, COMMON_ASSERTIONS (distinct shared EXPECT_*/ASSERT_* macros),
        COMMON_CALLS (distinct shared ``identifier(`` names), and
        ASSERTION_RATIO (shared distinct macros over total macro occurrences,
        or ``0`` when either side has none).
    """
    fixture_pattern = r'TEST(?:_F)?\s*\(\s*(\w+)'
    name_pattern = r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)'
    assertion_pattern = r'(EXPECT_|ASSERT_)(\w+)'
    call_pattern = r'(\w+)\s*\('

    def first_capture(pattern, text):
        # "" stands in for "no match" so every field is always present.
        found = re.search(pattern, text)
        return found.group(1) if found else ""

    fixture_a = first_capture(fixture_pattern, test_code_1)
    fixture_b = first_capture(fixture_pattern, test_code_2)
    test_name_a = first_capture(name_pattern, test_code_1)
    test_name_b = first_capture(name_pattern, test_code_2)

    # findall yields (prefix, suffix) tuples. The totals count every
    # occurrence, while the overlap counts distinct macro names only.
    macro_hits_a = re.findall(assertion_pattern, test_code_1)
    macro_hits_b = re.findall(assertion_pattern, test_code_2)
    shared_macros = {"".join(hit) for hit in macro_hits_a} & {"".join(hit) for hit in macro_hits_b}

    # Any identifier followed by "(" counts, including TEST/EXPECT_* macros.
    shared_callees = set(re.findall(call_pattern, test_code_1)) & set(re.findall(call_pattern, test_code_2))

    if fixture_a == fixture_b:
        fixture_flag = "SAME_FIXTURE"
    else:
        fixture_flag = "DIFFERENT_FIXTURE"

    # Integer 0 (not 0.0) when either side has no assertions.
    if macro_hits_a and macro_hits_b:
        assertion_ratio = len(shared_macros) / (len(macro_hits_a) + len(macro_hits_b))
    else:
        assertion_ratio = 0

    fields = [
        f"METADATA: {fixture_flag}",
        f"FIXTURE1: {fixture_a}",
        f"FIXTURE2: {fixture_b}",
        f"NAME1: {test_name_a}",
        f"NAME2: {test_name_b}",
        f"COMMON_ASSERTIONS: {len(shared_macros)}",
        f"COMMON_CALLS: {len(shared_callees)}",
        f"ASSERTION_RATIO: {assertion_ratio}",
    ]
    return " | ".join(fields)