Habiba A. Elbehairy committed on
Commit
a5cd505
·
1 Parent(s): 1306f0a

Refactor Code Similarity Classifier and update Dockerfile, README, and requirements

Browse files

- Updated Dockerfile to copy application files and set CMD for uvicorn.
- Revised README title and emoji for clarity.
- Enhanced app.py with a new CodeSimilarityClassifier model and feature extraction logic.
- Improved model loading and error handling in app.py.
- Added health check and prediction endpoints with detailed logging.
- Refactored model_definition.py to define CodeSimilarityClassifier with a more powerful classification head.
- Introduced feature extraction function for better similarity detection.
- Updated requirements.txt to include necessary packages.
- Added config.json for model architecture and parameters.

Files changed (6) hide show
  1. Dockerfile +0 -2
  2. README.md +4 -4
  3. app.py +221 -77
  4. config.json +29 -0
  5. model_definition.py +72 -27
  6. requirements.txt +5 -6
Dockerfile CHANGED
@@ -14,5 +14,3 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
 
15
  COPY --chown=user . /app
16
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
17
-
18
-
 
14
 
15
  COPY --chown=user . /app
16
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: CodeBert Redundant Detection Task
3
- emoji: 🔥
4
- colorFrom: blue
5
- colorTo: purple
6
  sdk: docker
7
  pinned: false
8
  ---
 
1
  ---
2
+ title: Code Similarity Classifier
3
+ emoji: 🐨
4
+ colorFrom: purple
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
  ---
app.py CHANGED
@@ -1,31 +1,37 @@
1
  import os
2
- import time
3
  import logging
4
  import torch
5
  import torch.nn.functional as F
6
  from fastapi import FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
- from transformers import AutoTokenizer, AutoConfig
10
- from model_definition import MultitaskCodeSimilarityModel
11
  from typing import List
12
  import uvicorn
13
  from datetime import datetime
 
 
 
 
14
 
15
  # Set up logging
16
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 
 
17
  logger = logging.getLogger(__name__)
18
 
19
- # System information - Updated with the provided values
20
- DEPLOYMENT_DATE = "2025-06-10 15:11:04" # Updated timestamp
21
- DEPLOYED_BY = "Fastest"
22
 
23
  # Get device
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  logger.info(f"Using device: {device}")
26
 
27
- # Your Hugging Face model repository
28
  REPO_ID = "FastestAI/Redundant_Model"
 
29
 
30
  # Initialize FastAPI app
31
  app = FastAPI(
@@ -35,7 +41,7 @@ app = FastAPI(
35
  docs_url="/",
36
  )
37
 
38
- # Add CORS middleware to allow cross-origin requests
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
@@ -44,11 +50,8 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
47
- # Define label to class mapping with CORRECT NUMBERING (1, 2, 3 instead of 0, 1, 2)
48
- label_to_class = {1: "Duplicate", 2: "Redundant", 3: "Distinct"}
49
-
50
- # Model output to API label mapping (if your model outputs 0, 1, 2 but we want 1, 2, 3)
51
- model_to_api_label = {0: 1, 1: 2, 2: 3}
52
 
53
  # Define input models for API
54
  class SourceCode(BaseModel):
@@ -69,59 +72,198 @@ class SimilarityInput(BaseModel):
69
  test_case_1: TestCase
70
  test_case_2: TestCase
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # Global variables for model and tokenizer
73
- model = None
74
  tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Load model and tokenizer on startup
77
  @app.on_event("startup")
78
  async def startup_event():
79
- global model, tokenizer
 
80
  try:
81
- logger.info(f"Loading model and tokenizer from {REPO_ID}...")
82
 
83
- # Load tokenizer directly from Hugging Face
84
- tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
 
 
 
 
 
 
85
 
86
- # Load config from Hugging Face
87
- config = AutoConfig.from_pretrained(REPO_ID)
 
 
 
 
 
 
 
88
 
89
- # Create model instance using imported MultitaskCodeSimilarityModel class
90
- model = MultitaskCodeSimilarityModel(config, tokenizer)
91
 
92
- # Load weights directly from Hugging Face
93
- state_dict = torch.hub.load_state_dict_from_url(
94
- f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin",
95
- map_location=device,
96
- check_hash=False
97
- )
98
- model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  # Move model to device and set to evaluation mode
101
  model.to(device)
102
  model.eval()
 
 
103
 
104
- logger.info("Model and tokenizer loaded successfully!")
105
  except Exception as e:
106
- logger.error(f"Error loading model: {e}")
107
  import traceback
108
  logger.error(traceback.format_exc())
109
  model = None
110
  tokenizer = None
111
 
112
- @app.get("/health", tags=["Health"])
113
  async def health_check():
114
  """Health check endpoint that also returns deployment information"""
115
- if model is None or tokenizer is None:
116
- return {
117
- "status": "error",
118
- "message": "Model or tokenizer not loaded",
119
- "deployment_date": DEPLOYMENT_DATE,
120
- "deployed_by": DEPLOYED_BY
121
- }
122
 
123
  return {
124
- "status": "ok",
 
 
125
  "model": REPO_ID,
126
  "device": str(device),
127
  "deployment_date": DEPLOYMENT_DATE,
@@ -133,11 +275,8 @@ async def health_check():
133
  async def predict(data: SimilarityInput):
134
  """
135
  Predict similarity class between two test cases for a given source class.
136
-
137
- Input schema follows the specified format with source_code, test_case_1, and test_case_2.
138
- Uses heuristics to detect class and method differences before using the model.
139
  """
140
- if model is None:
141
  raise HTTPException(status_code=500, detail="Model not loaded correctly")
142
 
143
  try:
@@ -150,28 +289,37 @@ async def predict(data: SimilarityInput):
150
  # Check if we can determine similarity without using the model
151
  if class_1 and class_2 and class_1 != class_2:
152
  logger.info(f"Heuristic detection: Different target classes - Distinct")
153
- api_prediction = 3 # Distinct
154
  probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
155
  elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
156
  logger.info(f"Heuristic detection: Different target methods - Distinct")
157
- api_prediction = 3 # Distinct
158
  probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
159
  else:
160
  # No clear heuristic match, use the model
161
- # Format input to match training format
162
- combined_input = (
163
- f"SOURCE CODE: {data.source_code.code}\n"
164
- f"TEST 1: {data.test_case_1.code}\n"
165
- f"TEST 2: {data.test_case_2.code}"
 
 
 
 
166
  )
167
 
168
  # Tokenize input
169
- inputs = tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
 
 
 
 
 
 
170
 
171
- # THIS IS WHERE THE MODEL IS CALLED
172
  with torch.no_grad():
173
- # Our custom model
174
- logits, _ = model(
175
  input_ids=inputs["input_ids"],
176
  attention_mask=inputs["attention_mask"]
177
  )
@@ -179,20 +327,20 @@ async def predict(data: SimilarityInput):
179
  # Process results
180
  probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
181
  model_prediction = torch.argmax(logits, dim=-1).item()
182
-
183
- # Convert model prediction (0,1,2) to API prediction (1,2,3)
184
- api_prediction = model_to_api_label[model_prediction]
185
- logger.info(f"Model prediction: {label_to_class[api_prediction]}")
186
 
187
  # Map prediction to class name
188
- classification = label_to_class.get(api_prediction, "Unknown")
 
 
 
189
 
190
  return {
191
  "pair_id": data.pair_id,
192
  "test_case_1_name": data.test_case_1.name,
193
  "test_case_2_name": data.test_case_2.name,
194
  "similarity": {
195
- "score": api_prediction,
196
  "classification": classification,
197
  },
198
  "probabilities": probs
@@ -205,8 +353,17 @@ async def predict(data: SimilarityInput):
205
  logger.error(error_trace)
206
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
207
 
208
- # Example endpoint
209
- @app.get("/example", response_model=SimilarityInput, tags=["Examples"])
 
 
 
 
 
 
 
 
 
210
  async def get_example():
211
  """Get an example input to test the API"""
212
  return SimilarityInput(
@@ -233,18 +390,5 @@ async def get_example():
233
  )
234
  )
235
 
236
- @app.get("/", tags=["Root"])
237
- async def root():
238
- """
239
- Redirect to the API documentation.
240
- This is a convenience endpoint that redirects to the auto-generated docs.
241
- """
242
- return {
243
- "message": "Test Similarity Analyzer API",
244
- "documentation": "/docs",
245
- "deployment_date": DEPLOYMENT_DATE,
246
- "deployed_by": DEPLOYED_BY
247
- }
248
-
249
  if __name__ == "__main__":
250
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
 
1
  import os
 
2
  import logging
3
  import torch
4
  import torch.nn.functional as F
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
 
 
8
  from typing import List
9
  import uvicorn
10
  from datetime import datetime
11
+ from transformers import AutoTokenizer, AutoModel
12
+ import requests
13
+ import re
14
+ import tempfile
15
 
16
  # Set up logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(levelname)s - %(message)s',
20
+ handlers=[logging.StreamHandler()]
21
+ )
22
  logger = logging.getLogger(__name__)
23
 
24
+ # System information - with your current values
25
+ DEPLOYMENT_DATE = "2025-06-22 22:15:13"
26
+ DEPLOYED_BY = "FASTESTAI"
27
 
28
  # Get device
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  logger.info(f"Using device: {device}")
31
 
32
+ # HuggingFace model repository path just for weights file
33
  REPO_ID = "FastestAI/Redundant_Model"
34
+ MODEL_WEIGHTS_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin"
35
 
36
  # Initialize FastAPI app
37
  app = FastAPI(
 
41
  docs_url="/",
42
  )
43
 
44
+ # Add CORS middleware
45
  app.add_middleware(
46
  CORSMiddleware,
47
  allow_origins=["*"],
 
50
  allow_headers=["*"],
51
  )
52
 
53
+ # Define label to class mapping
54
+ label_to_class = {0: "Duplicate", 1: "Redundant", 2: "Distinct"}
 
 
 
55
 
56
  # Define input models for API
57
  class SourceCode(BaseModel):
 
72
  test_case_1: TestCase
73
  test_case_2: TestCase
74
 
75
+ # Define the model class
76
class CodeSimilarityClassifier(torch.nn.Module):
    """Pretrained transformer encoder topped with a three-stage MLP classification head.

    Args:
        model_name: Identifier handed to ``AutoModel.from_pretrained`` to build the encoder.
        num_labels: Output width of the last linear layer (number of logits).
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        nn = torch.nn

        self.encoder = AutoModel.from_pretrained(model_name)
        # Kept as an attribute of the module; forward() does not apply it.
        self.dropout = nn.Dropout(0.1)

        # The head is sized from the encoder's own configuration.
        width = self.encoder.config.hidden_size

        # Order matters: the position of each layer inside the Sequential
        # determines its parameter names in the state dict.
        head_layers = [
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(width, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the classification head's logits.

        The head is fed the encoder's ``pooler_output``.
        """
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return self.classifier(encoded.pooler_output)
108
+
109
def extract_features(source_code, test_code_1, test_code_2):
    """Summarise how two test bodies overlap as a single ``METADATA: ...`` line.

    The summary contains the fixture and test name captured from the first
    ``TEST(...)`` / ``TEST_F(...)`` header of each test, whether the two
    fixtures are equal, the number of distinct ``EXPECT_*`` / ``ASSERT_*``
    macros both tests use, the number of distinct called identifiers both
    tests share, and an assertion ratio.

    Args:
        source_code: Accepted for interface compatibility; not inspected here.
        test_code_1: Text of the first test case.
        test_code_2: Text of the second test case.

    Returns:
        The pipe-separated feature string.
    """
    fixture_pattern = r'TEST(?:_F)?\s*\(\s*(\w+)'
    name_pattern = r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)'
    assertion_pattern = r'(EXPECT_|ASSERT_)(\w+)'
    call_pattern = r'(\w+)\s*\('

    def first_capture(pattern, text):
        # Captured group of the first match, or "" when nothing matches.
        found = re.search(pattern, text)
        return found.group(1) if found else ""

    fixture1 = first_capture(fixture_pattern, test_code_1)
    fixture2 = first_capture(fixture_pattern, test_code_2)
    name1 = first_capture(name_pattern, test_code_1)
    name2 = first_capture(name_pattern, test_code_2)

    # Each hit is a (prefix, suffix) tuple such as ("EXPECT_", "EQ").
    hits_1 = re.findall(assertion_pattern, test_code_1)
    hits_2 = re.findall(assertion_pattern, test_code_2)
    macros_1 = {prefix + suffix for prefix, suffix in hits_1}
    macros_2 = {prefix + suffix for prefix, suffix in hits_2}
    shared_macros = macros_1 & macros_2

    shared_calls = set(re.findall(call_pattern, test_code_1)) & set(
        re.findall(call_pattern, test_code_2)
    )

    if fixture1 == fixture2:
        fixture_tag = "SAME_FIXTURE"
    else:
        fixture_tag = "DIFFERENT_FIXTURE"

    # Distinct shared macros over the total number of assertion occurrences;
    # stays the integer 0 unless both tests contain at least one assertion.
    if hits_1 and hits_2:
        assertion_ratio = len(shared_macros) / (len(hits_1) + len(hits_2))
    else:
        assertion_ratio = 0

    parts = [
        f"METADATA: {fixture_tag} | ",
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | ",
        f"NAME1: {name1} | NAME2: {name2} | ",
        f"COMMON_ASSERTIONS: {len(shared_macros)} | ",
        f"COMMON_CALLS: {len(shared_calls)} | ",
        f"ASSERTION_RATIO: {assertion_ratio}",
    ]
    return "".join(parts)
156
+
157
  # Global variables for model and tokenizer
 
158
  tokenizer = None
159
+ model = None
160
+
161
def download_model_weights(url, save_path, timeout=(10, 60)):
    """Download model weights from ``url`` and store them at ``save_path``.

    The payload is streamed into a temporary file in the destination
    directory and only moved onto ``save_path`` once the transfer has
    finished. A failed or interrupted download therefore never leaves a
    truncated file at ``save_path`` that a later existence check could
    mistake for a complete checkpoint.

    Args:
        url: HTTP(S) location of the weights file.
        save_path: Final path of the downloaded file.
        timeout: ``requests`` timeout, either a single number or a
            ``(connect, read)`` tuple in seconds. Without it a stalled
            connection would block forever.

    Returns:
        True when the file was fully written to ``save_path``, False on any
        HTTP or I/O failure (the failure is logged, never raised).
    """
    tmp_path = None
    try:
        logger.info(f"Downloading model weights from {url}...")
        # The context manager returns the connection to the pool on every exit path.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                logger.error(f"Failed to download: HTTP {response.status_code}")
                return False

            # Same directory as the target so os.replace stays on one filesystem.
            target_dir = os.path.dirname(os.path.abspath(save_path))
            fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
            with os.fdopen(fd, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        os.replace(tmp_path, save_path)
        tmp_path = None  # ownership moved to save_path; nothing left to clean up
        logger.info(f"Successfully downloaded model weights to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Error downloading model weights: {e}")
        return False
    finally:
        # Remove a partial download, if one was left behind.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
179
 
180
  # Load model and tokenizer on startup
181
@app.on_event("startup")
async def startup_event():
    """Populate the module-level ``tokenizer`` and ``model`` when the app starts.

    Steps:
      1. Load the tokenizer for ``microsoft/codebert-base``.
      2. Build a ``CodeSimilarityClassifier`` on the same base model.
      3. Make sure ``pytorch_model.bin`` exists locally (downloading it from
         ``MODEL_WEIGHTS_URL`` if needed) and load it into the model. The
         state dict may be the checkpoint itself or sit under a
         ``"state_dict"`` / ``"model_state_dict"`` key.
      4. Move the model to ``device`` and switch it to eval mode.

    Any failure is logged together with its traceback and both globals are
    reset to ``None``; the exception is not propagated, so the application
    still starts.
    """
    global tokenizer, model

    try:
        logger.info("=== Starting model loading process ===")

        # Step 1: Load the tokenizer from the base model
        logger.info("Loading tokenizer from microsoft/codebert-base...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
            logger.info("✅ Base tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load tokenizer: {str(e)}")
            raise

        # Step 2: Create model with base architecture
        logger.info("Creating model architecture...")
        try:
            # Initialize with base CodeBERT
            model = CodeSimilarityClassifier(model_name="microsoft/codebert-base")
            logger.info("✅ Model architecture created successfully")
        except Exception as e:
            logger.error(f"❌ Failed to create model architecture: {str(e)}")
            raise

        # Step 3: Download and load weights
        model_path = "pytorch_model.bin"

        # Reuse a previously downloaded file when there is one.
        if not os.path.exists(model_path):
            if not download_model_weights(MODEL_WEIGHTS_URL, model_path):
                logger.error("❌ Failed to download model weights")
                raise RuntimeError("Failed to download model weights")

        try:
            logger.info(f"Loading weights from {model_path}...")
            # SECURITY: the file comes from the network and torch.load is
            # pickle-based. weights_only=True restricts unpickling to tensors
            # and plain containers instead of arbitrary Python objects.
            checkpoint = torch.load(
                model_path, map_location=device, weights_only=True
            )

            if isinstance(checkpoint, dict):
                # The state dict may be wrapped under a well-known key.
                if "state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['state_dict']")
                    model.load_state_dict(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['model_state_dict']")
                    model.load_state_dict(checkpoint["model_state_dict"])
                else:
                    logger.info("Loading from checkpoint directly")
                    model.load_state_dict(checkpoint)
            else:
                logger.error("❌ Unsupported model format")
                raise RuntimeError("Unsupported model format")

            logger.info("✅ Model weights loaded successfully")
        except Exception as e:
            logger.error(f"❌ Error loading model weights: {str(e)}")
            raise

        # Move model to device and set to evaluation mode
        model.to(device)
        model.eval()
        logger.info(f"✅ Model moved to {device} and set to evaluation mode")
        logger.info("=== Model loading process complete ===")

    except Exception as e:
        # Best effort by design: record the failure and leave the globals empty
        # rather than aborting application startup.
        logger.error(f" CRITICAL ERROR in startup: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        model = None
        tokenizer = None
255
 
256
+ @app.get("/health")
257
  async def health_check():
258
  """Health check endpoint that also returns deployment information"""
259
+ model_status = model is not None
260
+ tokenizer_status = tokenizer is not None
261
+ status = "ok" if (model_status and tokenizer_status) else "error"
 
 
 
 
262
 
263
  return {
264
+ "status": status,
265
+ "model_loaded": model_status,
266
+ "tokenizer_loaded": tokenizer_status,
267
  "model": REPO_ID,
268
  "device": str(device),
269
  "deployment_date": DEPLOYMENT_DATE,
 
275
  async def predict(data: SimilarityInput):
276
  """
277
  Predict similarity class between two test cases for a given source class.
 
 
 
278
  """
279
+ if model is None or tokenizer is None:
280
  raise HTTPException(status_code=500, detail="Model not loaded correctly")
281
 
282
  try:
 
289
  # Check if we can determine similarity without using the model
290
  if class_1 and class_2 and class_1 != class_2:
291
  logger.info(f"Heuristic detection: Different target classes - Distinct")
292
+ model_prediction = 2 # Distinct
293
  probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
294
  elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
295
  logger.info(f"Heuristic detection: Different target methods - Distinct")
296
+ model_prediction = 2 # Distinct
297
  probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
298
  else:
299
  # No clear heuristic match, use the model
300
+ # Extract features to help with classification
301
+ features = extract_features(data.source_code.code, data.test_case_1.code, data.test_case_2.code)
302
+
303
+ # Format the input text with clear section markers as done during training
304
+ formatted_text = (
305
+ f"{features}\n\n"
306
+ f"SOURCE CODE:\n{data.source_code.code.strip()}\n\n"
307
+ f"TEST CASE 1:\n{data.test_case_1.code.strip()}\n\n"
308
+ f"TEST CASE 2:\n{data.test_case_2.code.strip()}"
309
  )
310
 
311
  # Tokenize input
312
+ inputs = tokenizer(
313
+ formatted_text,
314
+ return_tensors="pt",
315
+ padding="max_length",
316
+ truncation=True,
317
+ max_length=512
318
+ ).to(device)
319
 
320
+ # Model inference
321
  with torch.no_grad():
322
+ logits = model(
 
323
  input_ids=inputs["input_ids"],
324
  attention_mask=inputs["attention_mask"]
325
  )
 
327
  # Process results
328
  probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
329
  model_prediction = torch.argmax(logits, dim=-1).item()
330
+ logger.info(f"Model prediction: {label_to_class[model_prediction]}")
 
 
 
331
 
332
  # Map prediction to class name
333
+ classification = label_to_class.get(model_prediction, "Unknown")
334
+
335
+ # For API compatibility, map the model outputs (0,1,2) to API scores (1,2,3)
336
+ api_score = model_prediction + 1
337
 
338
  return {
339
  "pair_id": data.pair_id,
340
  "test_case_1_name": data.test_case_1.name,
341
  "test_case_2_name": data.test_case_2.name,
342
  "similarity": {
343
+ "score": api_score,
344
  "classification": classification,
345
  },
346
  "probabilities": probs
 
353
  logger.error(error_trace)
354
  raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
355
 
356
+ # Root and example endpoints
357
@app.get("/")
async def root():
    """Return a small JSON banner: API name, docs location and deployment info.

    NOTE(review): the FastAPI app in this file is created with
    ``docs_url="/"``, so the interactive docs appear to be mounted on this
    same path and may take precedence over this handler -- confirm that this
    route is actually reachable.
    """
    return {
        "message": "Test Similarity Analyzer API",
        # Report where the docs are really mounted; a hard-coded "/docs" is
        # wrong whenever the app overrides docs_url.
        "documentation": app.docs_url,
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY
    }
365
+
366
+ @app.get("/example", response_model=SimilarityInput)
367
  async def get_example():
368
  """Get an example input to test the API"""
369
  return SimilarityInput(
 
390
  )
391
  )
392
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  if __name__ == "__main__":
394
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": ["CodeSimilarityClassifier"],
3
+ "attention_probs_dropout_prob": 0.1,
4
+ "hidden_act": "gelu",
5
+ "hidden_dropout_prob": 0.1,
6
+ "hidden_size": 768,
7
+ "initializer_range": 0.02,
8
+ "intermediate_size": 3072,
9
+ "max_position_embeddings": 514,
10
+ "num_attention_heads": 12,
11
+ "num_hidden_layers": 12,
12
+ "type_vocab_size": 1,
13
+ "vocab_size": 50265,
14
+ "layer_norm_eps": 1e-5,
15
+ "pad_token_id": 1,
16
+ "bos_token_id": 0,
17
+ "eos_token_id": 2,
18
+ "model_type": "codebert",
19
+ "problem_type": "single_label_classification",
20
+ "num_labels": 3,
21
+ "classifier_dropout": 0.1,
22
+ "classifier_hidden_size": 512,
23
+ "classifier_layers": 2,
24
+ "classifier_activation": "gelu",
25
+ "base_model_name": "microsoft/codebert-base",
26
+ "feature_extraction": true,
27
+ "deployment_date": "2025-06-22 22:17:05",
28
+ "deployed_by": "habibaelbehairy"
29
+ }
model_definition.py CHANGED
@@ -1,33 +1,78 @@
1
  import torch
2
  import torch.nn as nn
3
  from transformers import AutoModel
 
4
 
5
- class MultitaskCodeSimilarityModel(nn.Module):
6
- def __init__(self, config, tokenizer):
7
  super().__init__()
8
- self.config = config
9
- self.tokenizer = tokenizer
10
- self.encoder = AutoModel.from_config(config)
11
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
12
-
13
- # For explanation generation
14
- self.decoder_embedding = nn.Linear(config.hidden_size, config.hidden_size)
15
- self.decoder = nn.GRU(
16
- input_size=config.hidden_size,
17
- hidden_size=config.hidden_size,
18
- batch_first=True
 
 
 
 
 
 
 
 
 
 
 
 
19
  )
20
- self.explanation_head = nn.Linear(config.hidden_size, len(tokenizer))
21
-
22
- def forward(self, input_ids, attention_mask, explanation_ids=None, explanation_mask=None):
23
- outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
24
- pooled = outputs.last_hidden_state[:, 0]
25
- logits = self.classifier(pooled)
26
-
27
- explanation_logits = None
28
- if explanation_ids is not None:
29
- decoder_input = self.decoder_embedding(pooled).unsqueeze(1).expand(-1, explanation_ids.size(1), -1)
30
- decoder_outputs, _ = self.decoder(decoder_input)
31
- explanation_logits = self.explanation_head(decoder_outputs)
32
-
33
- return logits, explanation_logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import AutoModel
4
+ import re
5
 
6
class CodeSimilarityClassifier(nn.Module):
    """Transformer encoder plus an MLP head producing ``num_labels`` logits.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Size of the final projection.
    """

    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Stored on the module; forward() does not route anything through it.
        self.dropout = nn.Dropout(0.1)

        # Two Linear -> LayerNorm -> GELU -> Dropout stages (encoder width,
        # then 512), followed by the output projection. Layers are appended
        # in a fixed order because their position inside the Sequential
        # determines their parameter names in the state dict.
        stages = []
        in_features = self.encoder.config.hidden_size
        for out_features in (in_features, 512):
            stages.append(nn.Linear(in_features, out_features))
            stages.append(nn.LayerNorm(out_features))
            stages.append(nn.GELU())
            stages.append(nn.Dropout(0.1))
            in_features = out_features
        stages.append(nn.Linear(in_features, num_labels))

        self.classifier = nn.Sequential(*stages)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and classify its ``pooler_output``; returns the logits."""
        encoder_out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        pooled = encoder_out.pooler_output
        return self.classifier(pooled)
38
+
39
def extract_features(source_code, test_code_1, test_code_2):
    """Describe the overlap between two test cases as one ``METADATA`` line.

    For each test the first ``TEST(...)`` / ``TEST_F(...)`` header supplies
    the fixture and test name; ``EXPECT_*`` / ``ASSERT_*`` macros and every
    identifier followed by ``(`` are collected as well. The result reports
    whether the fixtures are equal, how many distinct assertion macros and
    called identifiers the tests share, and an assertion ratio.

    Args:
        source_code: Present in the signature but not read by this function.
        test_code_1: Text of the first test case.
        test_code_2: Text of the second test case.

    Returns:
        The pipe-separated feature string.
    """

    def scan(test_code):
        # Gather (fixture, name, assertion hits, called identifiers) for one test.
        fixture_hit = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', test_code)
        name_hit = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', test_code)
        return (
            fixture_hit.group(1) if fixture_hit else "",
            name_hit.group(1) if name_hit else "",
            re.findall(r'(EXPECT_|ASSERT_)(\w+)', test_code),
            re.findall(r'(\w+)\s*\(', test_code),
        )

    fixture1, name1, assertions1, calls1 = scan(test_code_1)
    fixture2, name2, assertions2, calls2 = scan(test_code_2)

    same_fixture = "DIFFERENT_FIXTURE" if fixture1 != fixture2 else "SAME_FIXTURE"

    # Join each (prefix, suffix) hit back into the full macro name before comparing.
    macro_names_1 = {"".join(hit) for hit in assertions1}
    macro_names_2 = {"".join(hit) for hit in assertions2}
    n_common_assertions = len(macro_names_1 & macro_names_2)
    n_common_calls = len(set(calls1) & set(calls2))

    # Integer 0 unless both tests contain assertions; otherwise a float ratio
    # of distinct shared macros to total assertion occurrences.
    ratio = 0
    if assertions1 and assertions2:
        ratio = n_common_assertions / (len(assertions1) + len(assertions2))

    features = (
        f"METADATA: {same_fixture} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {n_common_assertions} | "
        f"COMMON_CALLS: {n_common_calls} | "
        f"ASSERTION_RATIO: {ratio}"
    )
    return features
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
- torch>=1.10.0
2
- transformers>=4.18.0
3
- fastapi>=0.68.0
4
- uvicorn>=0.15.0
5
- pydantic>=1.8.0
6
- numpy>=1.20.0
 
1
+ fastapi
2
+ torch
3
+ transformers
4
+ uvicorn
5
+ requests