File size: 2,801 Bytes
1306f0a
 
 
a5cd505
1306f0a
a5cd505
 
1306f0a
a5cd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1306f0a
a5cd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
from transformers import AutoModel
import re

class CodeSimilarityClassifier(nn.Module):
    """Transformer encoder followed by an MLP head producing class logits.

    Args:
        model_name: Model identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output classes (width of the final Linear layer).
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Applied to the pooled sentence representation before the head.
        self.dropout = nn.Dropout(0.1)

        hidden_size = self.encoder.config.hidden_size

        # MLP head: hidden -> hidden -> 512 -> num_labels; each hidden stage is
        # LayerNorm-ed, GELU-activated and dropout-regularized.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        """Encode the tokens and return unnormalized class logits.

        Args:
            input_ids: Token ids for the encoder; presumably shaped
                (batch, seq_len) -- confirm with the caller.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Logits whose last dimension is ``num_labels``.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Not every encoder provides a pooler (the field may be absent or None,
        # e.g. when the model was built without a pooling layer); fall back to
        # the hidden state of the first token in that case.
        pooled_output = getattr(outputs, "pooler_output", None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

# Google Test macros that open a test: TEST, TEST_F, TEST_P, TYPED_TEST,
# TYPED_TEST_P. Group 1 is the suite/fixture name, group 2 (optional) the
# test name.
_TEST_HEADER_RE = re.compile(
    r'\b(?:TYPED_)?TEST(?:_[FP])?\s*\(\s*(\w+)(?:\s*,\s*(\w+))?'
)
# EXPECT_*/ASSERT_* assertion macros; the whole match is the macro name.
_ASSERTION_RE = re.compile(r'(?:EXPECT_|ASSERT_)\w+')
# An identifier immediately followed (modulo whitespace) by "(".
_CALL_RE = re.compile(r'(\w+)\s*\(')
# Names the call pattern matches that are not function/method calls.
_NON_CALL_NAMES = frozenset({
    "TEST", "TEST_F", "TEST_P", "TYPED_TEST", "TYPED_TEST_P",
    "if", "for", "while", "switch", "return", "catch",
    "sizeof", "alignof", "decltype", "noexcept", "static_assert",
})


def _parse_test_header(test_code):
    """Return (fixture, test_name) of the first test macro; "" for missing parts."""
    match = _TEST_HEADER_RE.search(test_code)
    if match is None:
        return "", ""
    return match.group(1), match.group(2) or ""


def _called_names(test_code):
    """Return distinct called identifiers, excluding test/assertion macros and keywords."""
    return {
        name
        for name in _CALL_RE.findall(test_code)
        if name not in _NON_CALL_NAMES
        and not name.startswith(("EXPECT_", "ASSERT_"))
    }


def extract_features(source_code, test_code_1, test_code_2):
    """Extract specific features to help the model identify similarities.

    Produces one line of ``KEY: value`` fields separated by ``" | "``:

    * ``METADATA`` -- ``SAME_FIXTURE`` or ``DIFFERENT_FIXTURE`` when both
      snippets declare a test, ``UNKNOWN_FIXTURE`` when either does not.
    * ``FIXTURE1``/``FIXTURE2``, ``NAME1``/``NAME2`` -- first and second
      argument of the first test macro in each snippet ("" if absent).
    * ``COMMON_ASSERTIONS`` -- number of distinct EXPECT_*/ASSERT_* macros
      used by both snippets.
    * ``COMMON_CALLS`` -- number of distinct called identifiers shared by
      both snippets (test/assertion macros and keywords excluded).
    * ``ASSERTION_RATIO`` -- Jaccard similarity of the two distinct
      assertion sets, three decimals; 0.000 when neither uses assertions.

    Args:
        source_code: Code under test. Currently unused; kept for interface
            compatibility.
        test_code_1: First test snippet.
        test_code_2: Second test snippet.

    Returns:
        The feature string.
    """
    fixture1, name1 = _parse_test_header(test_code_1)
    fixture2, name2 = _parse_test_header(test_code_2)

    assertions1 = set(_ASSERTION_RE.findall(test_code_1))
    assertions2 = set(_ASSERTION_RE.findall(test_code_2))
    common_assertions = assertions1 & assertions2
    all_assertions = assertions1 | assertions2
    # Distinct-vs-distinct ratio so identical assertion sets score 1.0.
    assertion_ratio = (
        len(common_assertions) / len(all_assertions) if all_assertions else 0.0
    )

    common_calls = _called_names(test_code_1) & _called_names(test_code_2)

    # Two missing fixtures must not be reported as "the same" fixture.
    if not fixture1 or not fixture2:
        same_fixture = "UNKNOWN_FIXTURE"
    elif fixture1 == fixture2:
        same_fixture = "SAME_FIXTURE"
    else:
        same_fixture = "DIFFERENT_FIXTURE"

    return (
        f"METADATA: {same_fixture} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {len(common_assertions)} | "
        f"COMMON_CALLS: {len(common_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio:.3f}"
    )