# CodeBert_Redundant_Detection_Task / model_definition.py
# Habiba A. Elbehairy
# Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
# a5cd505
import torch
import torch.nn as nn
from transformers import AutoModel
import re
class CodeSimilarityClassifier(nn.Module):
    """Pretrained transformer encoder followed by an MLP classification head.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``. A
    single vector per input is taken from the encoder output and passed
    through the head (two ``Linear -> LayerNorm -> GELU -> Dropout`` stages
    followed by a final ``Linear``) to produce ``num_labels`` logits.

    Args:
        model_name: Identifier or path handed to ``AutoModel.from_pretrained``.
        num_labels: Number of logits produced by the final linear layer.
        dropout: Dropout probability used for ``self.dropout`` and for the
            dropout layers inside the head. Defaults to the previously
            hard-coded value of 0.1.
    """

    def __init__(
        self,
        model_name: str = "microsoft/codebert-base",
        num_labels: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # NOTE: this layer is not applied in forward(); it is kept so that
        # code accessing ``model.dropout`` keeps working.
        self.dropout = nn.Dropout(dropout)

        # Classification head sized from the encoder's hidden width.
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return the classification logits.

        Args:
            input_ids: Forwarded to the encoder as ``input_ids`` (presumably
                token ids from the tokenizer matching ``model_name``).
            attention_mask: Forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of the classification head (unnormalised logits, last
            dimension ``num_labels``).
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Not every encoder exposes a pooled output: the field can be missing
        # or None. Fall back to the hidden state at sequence position 0
        # (assumes last_hidden_state is laid out as (batch, seq, hidden)).
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        return self.classifier(pooled_output)
def _first_capture(pattern, text):
    """Return capture group 1 of the first match of *pattern* in *text*, or ""."""
    found = re.search(pattern, text)
    if found is None:
        return ""
    return found.group(1)


def extract_features(source_code, test_code_1, test_code_2):
    """Build a one-line textual summary comparing two test snippets.

    The summary reports, separated by " | ": whether both snippets use the
    same fixture, each fixture and test name found in a ``TEST``/``TEST_F``
    header, how many distinct ``EXPECT_*``/``ASSERT_*`` macros and call-like
    identifiers the snippets share, and the shared-assertion count divided by
    the total number of assertion occurrences in both snippets.

    Args:
        source_code: Accepted for interface compatibility; not read here.
        test_code_1: Text of the first test.
        test_code_2: Text of the second test.

    Returns:
        The feature string. Missing fixtures or names appear as empty strings,
        and the ratio is the integer 0 when either snippet has no assertions.
    """
    fixture_pattern = r'TEST(?:_F)?\s*\(\s*(\w+)'
    name_pattern = r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)'
    assertion_pattern = r'(EXPECT_|ASSERT_)(\w+)'
    call_pattern = r'(\w+)\s*\('

    # First argument of the TEST/TEST_F header.
    fixture_a = _first_capture(fixture_pattern, test_code_1)
    fixture_b = _first_capture(fixture_pattern, test_code_2)

    # Second argument of the TEST/TEST_F header.
    test_name_a = _first_capture(name_pattern, test_code_1)
    test_name_b = _first_capture(name_pattern, test_code_2)

    # Every assertion occurrence, as (prefix, suffix) pairs.
    asserts_a = re.findall(assertion_pattern, test_code_1)
    asserts_b = re.findall(assertion_pattern, test_code_2)

    # Every identifier that is directly followed by an opening parenthesis.
    called_a = re.findall(call_pattern, test_code_1)
    called_b = re.findall(call_pattern, test_code_2)

    if fixture_a == fixture_b:
        fixture_tag = "SAME_FIXTURE"
    else:
        fixture_tag = "DIFFERENT_FIXTURE"

    macros_a = {prefix + suffix for prefix, suffix in asserts_a}
    macros_b = {prefix + suffix for prefix, suffix in asserts_b}
    shared_macros = macros_a & macros_b
    shared_calls = set(called_a) & set(called_b)

    # Distinct shared macros over total occurrences; guarded so an empty
    # snippet never triggers a division by zero.
    if asserts_a and asserts_b:
        ratio = len(shared_macros) / (len(asserts_a) + len(asserts_b))
    else:
        ratio = 0

    fields = [
        f"METADATA: {fixture_tag}",
        f"FIXTURE1: {fixture_a}",
        f"FIXTURE2: {fixture_b}",
        f"NAME1: {test_name_a}",
        f"NAME2: {test_name_b}",
        f"COMMON_ASSERTIONS: {len(shared_macros)}",
        f"COMMON_CALLS: {len(shared_calls)}",
        f"ASSERTION_RATIO: {ratio}",
    ]
    return " | ".join(fields)