Habiba A. Elbehairy
Refactor Code Similarity Classifier and update Dockerfile, README, and requirements
a5cd505
import os
import logging
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import uvicorn
from datetime import datetime
from transformers import AutoTokenizer, AutoModel
import requests
import re

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Deployment metadata
DEPLOYMENT_DATE = "2025-06-22 22:15:13"
DEPLOYED_BY = "FASTESTAI"

# Select the compute device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Hugging Face repository that hosts only the fine-tuned weights file
REPO_ID = "FastestAI/Redundant_Model"
MODEL_WEIGHTS_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin"

# Initialize FastAPI app
app = FastAPI(
    title="Test Similarity Analyzer API",
    description="API for analyzing similarity between test cases. Deployed by " + DEPLOYED_BY,
    version="1.0.0",
    docs_url="/",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Define label to class mapping
label_to_class = {0: "Duplicate", 1: "Redundant", 2: "Distinct"}

# Define input models for API
class SourceCode(BaseModel):
    class_name: str
    code: str

class TestCase(BaseModel):
    id: str
    test_fixture: str
    name: str
    code: str
    target_class: str
    target_method: List[str]

class SimilarityInput(BaseModel):
    pair_id: str
    source_code: SourceCode
    test_case_1: TestCase
    test_case_2: TestCase

# Define the model class
class CodeSimilarityClassifier(torch.nn.Module):
    def __init__(self, model_name="microsoft/codebert-base", num_labels=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = torch.nn.Dropout(0.1)
        # Classification head: two hidden layers with LayerNorm, GELU, and dropout
        # on top of the encoder's pooled output
        hidden_size = self.encoder.config.hidden_size
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(hidden_size, 512),
            torch.nn.LayerNorm(512),
            torch.nn.GELU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(512, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits
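
# Shape note (illustrative): for a tokenized batch of shape [batch_size, seq_len],
# the encoder's pooler_output is [batch_size, hidden_size] (768 for
# microsoft/codebert-base), and the classifier maps it to logits of shape
# [batch_size, num_labels], one score per similarity class.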

def extract_features(source_code, test_code_1, test_code_2):
    """Extract specific features to help the model identify similarities"""
    # Extract test fixtures
    fixture1 = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', test_code_1)
    fixture1 = fixture1.group(1) if fixture1 else ""
    fixture2 = re.search(r'TEST(?:_F)?\s*\(\s*(\w+)', test_code_2)
    fixture2 = fixture2.group(1) if fixture2 else ""

    # Extract test names
    name1 = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', test_code_1)
    name1 = name1.group(1) if name1 else ""
    name2 = re.search(r'TEST(?:_F)?\s*\(\s*\w+\s*,\s*(\w+)', test_code_2)
    name2 = name2.group(1) if name2 else ""

    # Extract assertions
    assertions1 = re.findall(r'(EXPECT_|ASSERT_)(\w+)', test_code_1)
    assertions2 = re.findall(r'(EXPECT_|ASSERT_)(\w+)', test_code_2)

    # Extract function/method calls
    calls1 = re.findall(r'(\w+)\s*\(', test_code_1)
    calls2 = re.findall(r'(\w+)\s*\(', test_code_2)

    # Create explicit feature section
    same_fixture = "SAME_FIXTURE" if fixture1 == fixture2 else "DIFFERENT_FIXTURE"
    common_assertions = set([a[0] + a[1] for a in assertions1]).intersection(set([a[0] + a[1] for a in assertions2]))
    common_calls = set(calls1).intersection(set(calls2))

    # Calculate assertion ratio with safety check for zero
    assertion_ratio = 0
    if assertions1 and assertions2:
        total_assertions = len(assertions1) + len(assertions2)
        if total_assertions > 0:
            assertion_ratio = len(common_assertions) / total_assertions

    features = (
        f"METADATA: {same_fixture} | "
        f"FIXTURE1: {fixture1} | FIXTURE2: {fixture2} | "
        f"NAME1: {name1} | NAME2: {name2} | "
        f"COMMON_ASSERTIONS: {len(common_assertions)} | "
        f"COMMON_CALLS: {len(common_calls)} | "
        f"ASSERTION_RATIO: {assertion_ratio}"
    )
    return features
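
# Illustrative example: for the two Calculator tests returned by the example
# endpoint below, both use the CalculatorTest fixture, one EXPECT_EQ each, and
# the calls TEST, EXPECT_EQ, and add, so this function would produce:
#   "METADATA: SAME_FIXTURE | FIXTURE1: CalculatorTest | FIXTURE2: CalculatorTest |
#    NAME1: AddsTwoPositiveNumbers | NAME2: AddsTwoPositiveIntegers |
#    COMMON_ASSERTIONS: 1 | COMMON_CALLS: 3 | ASSERTION_RATIO: 0.5"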

# Global variables for model and tokenizer
tokenizer = None
model = None

def download_model_weights(url, save_path):
    """Download model weights from URL to a local file"""
    try:
        logger.info(f"Downloading model weights from {url}...")
        # Stream the download; the timeout keeps a stalled connection from hanging startup
        response = requests.get(url, stream=True, timeout=60)
        if response.status_code != 200:
            logger.error(f"Failed to download: HTTP {response.status_code}")
            return False
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)
        logger.info(f"Successfully downloaded model weights to {save_path}")
        return True
    except Exception as e:
        logger.error(f"Error downloading model weights: {e}")
        return False

# Load model and tokenizer on startup
@app.on_event("startup")
async def startup_event():
    global tokenizer, model
    try:
        logger.info("=== Starting model loading process ===")

        # Step 1: Load the tokenizer from the base model
        logger.info("Loading tokenizer from microsoft/codebert-base...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
            logger.info("✅ Base tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load tokenizer: {str(e)}")
            raise

        # Step 2: Create model with base architecture
        logger.info("Creating model architecture...")
        try:
            # Initialize with base CodeBERT
            model = CodeSimilarityClassifier(model_name="microsoft/codebert-base")
            logger.info("✅ Model architecture created successfully")
        except Exception as e:
            logger.error(f"❌ Failed to create model architecture: {str(e)}")
            raise

        # Step 3: Download and load weights
        model_path = "pytorch_model.bin"
        # First check if the file already exists
        if not os.path.exists(model_path):
            # Try downloading
            if not download_model_weights(MODEL_WEIGHTS_URL, model_path):
                logger.error("❌ Failed to download model weights")
                raise RuntimeError("Failed to download model weights")

        # Try to load the model weights
        try:
            # Check if the weights are a state dict or the whole model
            logger.info(f"Loading weights from {model_path}...")
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict):
                # If it's a state dict directly
                if "state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['state_dict']")
                    model.load_state_dict(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    logger.info("Loading from checkpoint['model_state_dict']")
                    model.load_state_dict(checkpoint["model_state_dict"])
                else:
                    logger.info("Loading from checkpoint directly")
                    model.load_state_dict(checkpoint)
            else:
                logger.error("❌ Unsupported model format")
                raise RuntimeError("Unsupported model format")
            logger.info("✅ Model weights loaded successfully")
        except Exception as e:
            logger.error(f"❌ Error loading model weights: {str(e)}")
            raise

        # Move model to device and set to evaluation mode
        model.to(device)
        model.eval()
        logger.info(f"✅ Model moved to {device} and set to evaluation mode")
        logger.info("=== Model loading process complete ===")
    except Exception as e:
        logger.error(f"❌ CRITICAL ERROR in startup: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        model = None
        tokenizer = None

# Health check endpoint; the /health route path is assumed here
@app.get("/health")
async def health_check():
    """Health check endpoint that also returns deployment information"""
    model_status = model is not None
    tokenizer_status = tokenizer is not None
    status = "ok" if (model_status and tokenizer_status) else "error"
    return {
        "status": status,
        "model_loaded": model_status,
        "tokenizer_loaded": tokenizer_status,
        "model": REPO_ID,
        "device": str(device),
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY,
        "current_time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
    }
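
# Illustrative healthy response (values vary by deployment):
#   {"status": "ok", "model_loaded": true, "tokenizer_loaded": true,
#    "model": "FastestAI/Redundant_Model", "device": "cuda",
#    "deployment_date": "2025-06-22 22:15:13", "deployed_by": "FASTESTAI",
#    "current_time": "2025-06-22 22:20:00"}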

# Prediction endpoint; the /predict route path is assumed here
@app.post("/predict")
async def predict(data: SimilarityInput):
    """
    Predict similarity class between two test cases for a given source class.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded correctly")

    try:
        # Apply heuristics for method and class differences
        class_1 = data.test_case_1.target_class
        class_2 = data.test_case_2.target_class
        method_1 = data.test_case_1.target_method
        method_2 = data.test_case_2.target_method

        # Check if we can determine similarity without using the model
        if class_1 and class_2 and class_1 != class_2:
            logger.info("Heuristic detection: Different target classes - Distinct")
            model_prediction = 2  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
            logger.info("Heuristic detection: Different target methods - Distinct")
            model_prediction = 2  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        else:
            # No clear heuristic match, use the model
            # Extract features to help with classification
            features = extract_features(data.source_code.code, data.test_case_1.code, data.test_case_2.code)

            # Format the input text with clear section markers as done during training
            formatted_text = (
                f"{features}\n\n"
                f"SOURCE CODE:\n{data.source_code.code.strip()}\n\n"
                f"TEST CASE 1:\n{data.test_case_1.code.strip()}\n\n"
                f"TEST CASE 2:\n{data.test_case_2.code.strip()}"
            )

            # Tokenize input
            inputs = tokenizer(
                formatted_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512
            ).to(device)

            # Model inference
            with torch.no_grad():
                logits = model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"]
                )

            # Process results
            probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
            model_prediction = torch.argmax(logits, dim=-1).item()
            logger.info(f"Model prediction: {label_to_class[model_prediction]}")

        # Map prediction to class name
        classification = label_to_class.get(model_prediction, "Unknown")

        # For API compatibility, map the model outputs (0, 1, 2) to API scores (1, 2, 3)
        api_score = model_prediction + 1

        return {
            "pair_id": data.pair_id,
            "test_case_1_name": data.test_case_1.name,
            "test_case_2_name": data.test_case_2.name,
            "similarity": {
                "score": api_score,
                "classification": classification,
            },
            "probabilities": probs
        }
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        logger.error(f"Prediction error: {str(e)}")
        logger.error(error_trace)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
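
# Illustrative response shape (probability values are placeholders, ordered
# [Duplicate, Redundant, Distinct]; "score" maps those labels to 1, 2, 3):
#   {"pair_id": "example-1",
#    "test_case_1_name": "testAddsTwoPositiveNumbers",
#    "test_case_2_name": "testAddsTwoPositiveIntegers",
#    "similarity": {"score": 1, "classification": "Duplicate"},
#    "probabilities": [0.90, 0.07, 0.03]}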

# Root and example endpoints
async def root():
    return {
        "message": "Test Similarity Analyzer API",
        "documentation": "/",
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY
    }

# Example payload endpoint; the /example route path is assumed here
@app.get("/example")
async def get_example():
    """Get an example input to test the API"""
    return SimilarityInput(
        pair_id="example-1",
        source_code=SourceCode(
            class_name="Calculator",
            code="class Calculator {\n public int add(int a, int b) {\n return a + b;\n }\n}"
        ),
        test_case_1=TestCase(
            id="test-1",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveNumbers",
            code="TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        ),
        test_case_2=TestCase(
            id="test-2",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveIntegers",
            code="TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        )
    )
if __name__ == "__main__": | |
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) |
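
# Local usage sketch (assumes this module is saved as app.py, matching "app:app",
# and the /health, /example, and /predict routes registered above):
#   python app.py                       # serves on http://localhost:7860, docs at "/"
#   curl http://localhost:7860/health
#   curl -s http://localhost:7860/example | curl -s -X POST \
#        -H "Content-Type: application/json" -d @- http://localhost:7860/predict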