import re

import torch
import torch.nn as nn
from transformers import AutoModel


# googletest-style test-definition macros: TEST, TEST_F, TEST_P, TYPED_TEST,
# TYPED_TEST_P. Group 1 is the first macro argument (suite/fixture), group 2 the
# optional second argument (test name). The leading \b stops e.g. "MY_TEST(" or
# "INSTANTIATE_TEST_SUITE_P(" from being mistaken for a test definition.
_TEST_MACRO_RE = re.compile(r"\b(?:TYPED_)?TEST(?:_[FP])?\s*\(\s*(\w+)(?:\s*,\s*(\w+))?")

# Assertion macros such as EXPECT_EQ / ASSERT_TRUE; group 1 is the full macro name.
_ASSERTION_RE = re.compile(r"\b((?:EXPECT|ASSERT)_\w+)")

# Identifier immediately followed by "(" -- a cheap approximation of a call site.
_CALL_RE = re.compile(r"(\w+)\s*\(")

# Names _CALL_RE picks up that are not real calls. C/C++ keywords that take a
# parenthesised operand would otherwise count as "calls" shared by most tests.
_NON_CALL_KEYWORDS = frozenset(
    {"if", "for", "while", "switch", "return", "sizeof", "catch", "decltype"}
)

# Test/assertion macros are excluded from call matching: every test shares its
# TEST macro, and assertions are already reported separately.
_GTEST_MACRO_RE = re.compile(r"(?:TYPED_)?TEST(?:_[FP])?|(?:EXPECT|ASSERT)_\w+")


class CodeSimilarityClassifier(nn.Module):
    """Pretrained transformer encoder followed by an MLP classification head.

    Args:
        model_name: Hugging Face model id or path passed to
            ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (size of the final linear layer).
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Applied to the pooled encoder representation before the head.
        self.dropout = nn.Dropout(0.1)

        # Three-layer MLP head: hidden -> hidden -> 512 -> num_labels, with
        # LayerNorm + GELU + dropout between the linear layers.
        hidden_size = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        )

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return unnormalised class logits.

        Args:
            input_ids: Token ids forwarded to the encoder (presumably
                ``(batch, seq_len)`` -- confirm against the tokenizer call site).
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Logits whose last dimension is ``num_labels``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Some encoders (or checkpoints loaded without a pooling layer) return
        # pooler_output=None or no such field at all; fall back to the first
        # token's final hidden state in that case.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)


def _test_identity(test_code):
    """Return ``(fixture, name)`` of the first test macro in *test_code*.

    Both values come from the same macro; a missing part is returned as "".
    """
    match = _TEST_MACRO_RE.search(test_code)
    if match is None:
        return "", ""
    return match.group(1), match.group(2) or ""


def _assertion_names(test_code):
    """Return the set of distinct assertion macro names used in *test_code*."""
    return set(_ASSERTION_RE.findall(test_code))


def _called_names(test_code):
    """Return distinct call-like identifiers, excluding keywords and test macros."""
    return {
        name
        for name in _CALL_RE.findall(test_code)
        if name not in _NON_CALL_KEYWORDS and not _GTEST_MACRO_RE.fullmatch(name)
    }


def extract_features(source_code, test_code_1, test_code_2):
    """Extract specific features to help the model identify similarities.

    Compares two googletest-style tests (TEST/TEST_F/..., EXPECT_*/ASSERT_*)
    and summarises the comparison as a single line of text.

    Args:
        source_code: Code under test. Currently unused; kept so existing
            callers remain valid.
        test_code_1: Source text of the first test.
        test_code_2: Source text of the second test.

    Returns:
        A string of the form
        ``"METADATA: <SAME|DIFFERENT|UNKNOWN>_FIXTURE | FIXTURE1: ... |
        FIXTURE2: ... | NAME1: ... | NAME2: ... | COMMON_ASSERTIONS: <int> |
        COMMON_CALLS: <int> | ASSERTION_RATIO: <x.xxx>"``.
        ASSERTION_RATIO is the Jaccard similarity of the two tests' distinct
        assertion macro names (0.000 when neither test has an assertion).
    """
    fixture1, name1 = _test_identity(test_code_1)
    fixture2, name2 = _test_identity(test_code_2)

    # Two unparseable tests must not be reported as sharing a fixture.
    if not fixture1 or not fixture2:
        same_fixture = "UNKNOWN_FIXTURE"
    elif fixture1 == fixture2:
        same_fixture = "SAME_FIXTURE"
    else:
        same_fixture = "DIFFERENT_FIXTURE"

    assertions1 = _assertion_names(test_code_1)
    assertions2 = _assertion_names(test_code_2)
    common_assertions = assertions1 & assertions2
    all_assertions = assertions1 | assertions2
    # Compare distinct names against distinct names so that repeated use of the
    # same macro does not dilute the ratio.
    assertion_ratio = (
        len(common_assertions) / len(all_assertions) if all_assertions else 0.0
    )

    common_calls = _called_names(test_code_1) & _called_names(test_code_2)

    return (
        f"METADATA: {same_fixture} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {len(common_assertions)} | "
        f"COMMON_CALLS: {len(common_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio:.3f}"
    )