import logging
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoConfig
from model_definition import MultitaskCodeSimilarityModel
from typing import List
import uvicorn
from datetime import datetime
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Deployment metadata
DEPLOYMENT_DATE = "2025-06-10 15:11:04"
DEPLOYED_BY = "Fastest"
# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Your Hugging Face model repository
REPO_ID = "FastestAI/Redundant_Model"
# Initialize FastAPI app
app = FastAPI(
    title="Test Similarity Analyzer API",
    description="API for analyzing similarity between test cases. Deployed by " + DEPLOYED_BY,
    version="1.0.0",
    docs_url="/",
)
# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Map API labels (1-3) to class names
label_to_class = {1: "Duplicate", 2: "Redundant", 3: "Distinct"}
# Map raw model outputs (0-2) to the 1-based API labels
model_to_api_label = {0: 1, 1: 2, 2: 3}
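# Example: a raw model output of 0 maps to API label 1, which reads as
# "Duplicate" (model_to_api_label[0] == 1, label_to_class[1] == "Duplicate").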
# Define input models for API
class SourceCode(BaseModel):
    class_name: str
    code: str

class TestCase(BaseModel):
    id: str
    test_fixture: str
    name: str
    code: str
    target_class: str
    target_method: List[str]

class SimilarityInput(BaseModel):
    pair_id: str
    source_code: SourceCode
    test_case_1: TestCase
    test_case_2: TestCase
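# Illustrative JSON body for /predict (field values are hypothetical; the
# /example endpoint below returns a complete, ready-to-send instance):
# {
#   "pair_id": "pair-1",
#   "source_code": {"class_name": "Calculator", "code": "..."},
#   "test_case_1": {"id": "t1", "test_fixture": "CalculatorTest",
#                   "name": "testAdd", "code": "...",
#                   "target_class": "Calculator", "target_method": ["add"]},
#   "test_case_2": { ...same shape as test_case_1... }
# }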
# Global variables for model and tokenizer
model = None
tokenizer = None
# Load model and tokenizer on startup
@app.on_event("startup")
async def startup_event():
    global model, tokenizer
    try:
        logger.info(f"Loading model and tokenizer from {REPO_ID}...")
        # Load tokenizer directly from Hugging Face
        tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
        # Load config from Hugging Face
        config = AutoConfig.from_pretrained(REPO_ID)
        # Create model instance using the imported MultitaskCodeSimilarityModel class
        model = MultitaskCodeSimilarityModel(config, tokenizer)
        # Load weights directly from Hugging Face
        state_dict = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin",
            map_location=device,
            check_hash=False
        )
        model.load_state_dict(state_dict)
        # Move model to device and set to evaluation mode
        model.to(device)
        model.eval()
        logger.info("Model and tokenizer loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        model = None
        tokenizer = None
@app.get("/health", tags=["Health"])
async def health_check():
"""Health check endpoint that also returns deployment information"""
if model is None or tokenizer is None:
return {
"status": "error",
"message": "Model or tokenizer not loaded",
"deployment_date": DEPLOYMENT_DATE,
"deployed_by": DEPLOYED_BY
}
return {
"status": "ok",
"model": REPO_ID,
"device": str(device),
"deployment_date": DEPLOYMENT_DATE,
"deployed_by": DEPLOYED_BY,
"current_time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
}
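# Example health check (assumes the server is running locally on the port
# used at the bottom of this file, 7860):
#   curl http://localhost:7860/health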
@app.post("/predict")
async def predict(data: SimilarityInput):
"""
Predict similarity class between two test cases for a given source class.
Input schema follows the specified format with source_code, test_case_1, and test_case_2.
Uses heuristics to detect class and method differences before using the model.
"""
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded correctly")
try:
# Apply heuristics for method and class differences
class_1 = data.test_case_1.target_class
class_2 = data.test_case_2.target_class
method_1 = data.test_case_1.target_method
method_2 = data.test_case_2.target_method
# Check if we can determine similarity without using the model
if class_1 and class_2 and class_1 != class_2:
logger.info(f"Heuristic detection: Different target classes - Distinct")
api_prediction = 3 # Distinct
probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
logger.info(f"Heuristic detection: Different target methods - Distinct")
api_prediction = 3 # Distinct
probs = [0.0, 0.0, 1.0] # 100% confidence in Distinct
else:
# No clear heuristic match, use the model
# Format input to match training format
combined_input = (
f"SOURCE CODE: {data.source_code.code}\n"
f"TEST 1: {data.test_case_1.code}\n"
f"TEST 2: {data.test_case_2.code}"
)
# Tokenize input
inputs = tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
# THIS IS WHERE THE MODEL IS CALLED
with torch.no_grad():
# Our custom model
logits, _ = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"]
)
# Process results
probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
model_prediction = torch.argmax(logits, dim=-1).item()
# Convert model prediction (0,1,2) to API prediction (1,2,3)
api_prediction = model_to_api_label[model_prediction]
logger.info(f"Model prediction: {label_to_class[api_prediction]}")
# Map prediction to class name
classification = label_to_class.get(api_prediction, "Unknown")
return {
"pair_id": data.pair_id,
"test_case_1_name": data.test_case_1.name,
"test_case_2_name": data.test_case_2.name,
"similarity": {
"score": api_prediction,
"classification": classification,
},
"probabilities": probs
}
except Exception as e:
import traceback
error_trace = traceback.format_exc()
logger.error(f"Prediction error: {str(e)}")
logger.error(error_trace)
raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
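# Example prediction request (hypothetical local server and file name; the
# JSON file holds a SimilarityInput payload, e.g. the one served by /example):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d @pair.json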
# Example endpoint
@app.get("/example", response_model=SimilarityInput, tags=["Examples"])
async def get_example():
    """Get an example input to test the API."""
    return SimilarityInput(
        pair_id="example-1",
        source_code=SourceCode(
            class_name="Calculator",
            code="class Calculator {\n public int add(int a, int b) {\n return a + b;\n }\n}"
        ),
        test_case_1=TestCase(
            id="test-1",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveNumbers",
            code="TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        ),
        test_case_2=TestCase(
            id="test-2",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveIntegers",
            code="TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        )
    )
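# The payload above can be fetched and POSTed straight back to /predict
# (assuming a local server on port 7860):
#   curl http://localhost:7860/example -o pair.json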
@app.get("/", tags=["Root"])
async def root():
"""
Redirect to the API documentation.
This is a convenience endpoint that redirects to the auto-generated docs.
"""
return {
"message": "Test Similarity Analyzer API",
"documentation": "/docs",
"deployment_date": DEPLOYMENT_DATE,
"deployed_by": DEPLOYED_BY
}
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
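# Run locally with `python app.py` (the "app:app" import string assumes this
# file is saved as app.py), or equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860 --reload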