|
import os |
|
import time |
|
import logging |
|
import torch |
|
import torch.nn.functional as F |
|
from fastapi import FastAPI, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoConfig |
|
from model_definition import MultitaskCodeSimilarityModel |
|
from typing import List |
|
import uvicorn |
|
from datetime import datetime |
|
|
|
|
|
# Process-wide logging: INFO level, "timestamp - level - message" records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Static deployment metadata, echoed verbatim by the informational endpoints.
DEPLOYED_BY = "Fastest"
DEPLOYMENT_DATE = "2025-06-10 15:11:04"
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# Hugging Face repository that hosts the tokenizer, config and weights.
REPO_ID = "FastestAI/Redundant_Model"
|
|
|
|
|
# The ASGI application. ``docs_url="/"`` puts the interactive API docs at the
# site root instead of FastAPI's default location.
app = FastAPI(
    docs_url="/",
    version="1.0.0",
    title="Test Similarity Analyzer API",
    description="API for analyzing similarity between test cases. Deployed by " + DEPLOYED_BY,
)

# Fully permissive CORS policy: any origin, method and header.
# NOTE(review): FastAPI's CORS documentation says wildcard origins/methods/
# headers must not be combined with allow_credentials=True. The values are
# left exactly as they were; confirm whether credentials are really needed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
# API label (1-based) -> human-readable similarity class.
label_to_class = dict(enumerate(("Duplicate", "Redundant", "Distinct"), start=1))

# Model output index (0-based) -> API label (1-based).
model_to_api_label = {index: index + 1 for index in range(3)}
|
|
|
|
|
# Request sub-model: the production class that both test cases exercise.
class SourceCode(BaseModel):
    # Name of the class under test.
    class_name: str
    # Source text of that class; it is embedded in the model prompt.
    code: str
|
|
|
# Request sub-model: one test case of the pair being compared.
class TestCase(BaseModel):
    # --- identification ---
    id: str
    test_fixture: str
    name: str

    # --- test body; embedded in the model prompt ---
    code: str

    # --- what the test targets; compared by the "Distinct" heuristics ---
    target_class: str
    target_method: List[str]
|
|
|
# Request body of /predict: one source class plus the two tests to compare.
class SimilarityInput(BaseModel):
    # Caller-supplied identifier, returned unchanged in the response.
    pair_id: str
    # Class under test.
    source_code: SourceCode
    # The two test cases whose similarity is classified.
    test_case_1: TestCase
    test_case_2: TestCase
|
|
|
|
|
# Inference singletons. Both stay None until the startup hook has loaded
# them; the endpoints use None to detect "not ready".
model = tokenizer = None
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the tokenizer and model once, when the application starts.

    On success the module-level ``model`` and ``tokenizer`` are populated.
    Any failure is logged with its traceback and both globals are left as
    ``None``; the exception is deliberately not re-raised, so the process
    still starts and the endpoints can report that the model is unavailable.
    """
    global model, tokenizer
    try:
        logger.info("Loading model and tokenizer from %s...", REPO_ID)

        loaded_tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
        config = AutoConfig.from_pretrained(REPO_ID)
        loaded_model = MultitaskCodeSimilarityModel(config, loaded_tokenizer)

        # NOTE(review): this downloads a pickled checkpoint with hash checking
        # disabled. Unpickling can execute arbitrary code, so the repository
        # must be trusted; consider pinning a revision and verifying a hash.
        state_dict = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin",
            map_location=device,
            check_hash=False,
        )
        loaded_model.load_state_dict(state_dict)

        loaded_model.to(device)
        loaded_model.eval()
    except Exception as e:
        # Best-effort startup: record the full traceback and stay "not ready".
        logger.exception("Error loading model: %s", e)
        model = None
        tokenizer = None
    else:
        # Publish the globals only after everything succeeded, so a model
        # without its weights is never visible to request handlers.
        model = loaded_model
        tokenizer = loaded_tokenizer
        logger.info("Model and tokenizer loaded successfully!")
|
|
|
@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint that also returns deployment information"""
    # Function-scope import: the module header only imports ``datetime``.
    from datetime import timezone

    # Metadata shared by the healthy and the unhealthy payloads.
    deployment_info = {
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY,
    }

    # Either global being None means the startup load did not complete.
    # NOTE(review): this is still answered with HTTP 200; only the "status"
    # field tells callers that the service is not ready.
    if model is None or tokenizer is None:
        return {
            "status": "error",
            "message": "Model or tokenizer not loaded",
            **deployment_info,
        }

    # datetime.utcnow() is deprecated and returns a naive value; an aware UTC
    # timestamp renders to the same "%Y-%m-%d %H:%M:%S" text.
    current_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    return {
        "status": "ok",
        "model": REPO_ID,
        "device": str(device),
        **deployment_info,
        "current_time": current_time,
    }
|
|
|
def _distinct_by_heuristic(test_1: TestCase, test_2: TestCase):
    """Detect, without the model, test pairs that target different code.

    Args:
        test_1: First test case; ``target_class`` and ``target_method`` are read.
        test_2: Second test case; the same attributes are read.

    Returns:
        A short reason string when the pair is treated as "Distinct", or
        ``None`` when the decision has to be left to the model.
    """
    class_1, class_2 = test_1.target_class, test_2.target_class
    # Empty values never trigger the shortcut; both sides must be set.
    if class_1 and class_2 and class_1 != class_2:
        return "Different target classes"

    methods_1, methods_2 = test_1.target_method, test_2.target_method
    if methods_1 and methods_2 and set(methods_1).isdisjoint(methods_2):
        return "Different target methods"

    return None


def _classify_with_model(data: SimilarityInput):
    """Classify the pair with the loaded model.

    Args:
        data: The validated request; the source code and both test bodies
            are combined into a single prompt.

    Returns:
        A ``(api_label, probabilities)`` tuple: the API label mapped from the
        arg-max model index, and the softmax probabilities as a list.

    Raises:
        KeyError: If the arg-max index has no entry in ``model_to_api_label``.
    """
    combined_input = (
        f"SOURCE CODE: {data.source_code.code}\n"
        f"TEST 1: {data.test_case_1.code}\n"
        f"TEST 2: {data.test_case_2.code}"
    )

    # NOTE(review): the prompt is truncated to 512 tokens and the source code
    # comes first, so a long class presumably pushes the test bodies out of
    # the window -- confirm the tokenizer's truncation side.
    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        # The model returns a pair; only the first element (logits) is used.
        logits, _ = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
    model_prediction = torch.argmax(logits, dim=-1).item()
    return model_to_api_label[model_prediction], probs


@app.post("/predict")
async def predict(data: SimilarityInput):
    """
    Predict similarity class between two test cases for a given source class.

    Input schema follows the specified format with source_code, test_case_1, and test_case_2.
    Uses heuristics to detect class and method differences before using the model.
    """
    # Both objects are needed for inference; checking only the model would
    # surface a missing tokenizer as an obscure "NoneType is not callable".
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model not loaded correctly")

    try:
        reason = _distinct_by_heuristic(data.test_case_1, data.test_case_2)
        if reason is not None:
            logger.info("Heuristic detection: %s - Distinct", reason)
            # Fixed result for the shortcut: label 3 with full confidence.
            api_prediction = 3
            probs = [0.0, 0.0, 1.0]
        else:
            api_prediction, probs = _classify_with_model(data)
            logger.info("Model prediction: %s", label_to_class[api_prediction])

        classification = label_to_class.get(api_prediction, "Unknown")

        return {
            "pair_id": data.pair_id,
            "test_case_1_name": data.test_case_1.name,
            "test_case_2_name": data.test_case_2.name,
            "similarity": {
                "score": api_prediction,
                "classification": classification,
            },
            "probabilities": probs,
        }
    except HTTPException:
        # Never re-wrap a deliberate HTTP error as a generic 500.
        raise
    except Exception as e:
        # Logs the message together with the full traceback.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
|
|
|
|
|
@app.get("/example", response_model=SimilarityInput, tags=["Examples"])
async def get_example():
    """Get an example input to test the API"""
    # Class under test: a minimal Calculator with a single add() method.
    calculator_source = SourceCode(
        class_name="Calculator",
        code="class Calculator {\n public int add(int a, int b) {\n return a + b;\n }\n}",
    )

    # Two tests with different names that make the same assertion on add().
    first_test = TestCase(
        id="test-1",
        test_fixture="CalculatorTest",
        name="testAddsTwoPositiveNumbers",
        code="TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
        target_class="Calculator",
        target_method=["add"],
    )
    second_test = TestCase(
        id="test-2",
        test_fixture="CalculatorTest",
        name="testAddsTwoPositiveIntegers",
        code="TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n Calculator calc;\n EXPECT_EQ(5, calc.add(2, 3));\n}",
        target_class="Calculator",
        target_method=["add"],
    )

    return SimilarityInput(
        pair_id="example-1",
        source_code=calculator_source,
        test_case_1=first_test,
        test_case_2=second_test,
    )
|
|
|
@app.get("/", tags=["Root"])
async def root():
    """
    Return basic information about the API.

    The response carries the API name, the documentation URL and the deployment metadata.
    """
    # NOTE(review): the application is created with docs_url="/", and FastAPI
    # registers its documentation route when the app object is built, so this
    # handler is presumably shadowed by the docs page -- confirm, then either
    # move this route or drop it.
    return {
        "message": "Test Similarity Analyzer API",
        # Report where the docs are actually served instead of a hard-coded
        # "/docs", which is not the configured location.
        "documentation": app.docs_url,
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY,
    }
|
|
|
if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 7860.
    # NOTE(review): "app:app" only resolves when this module is importable as
    # ``app``, and reload=True is normally a development setting -- confirm
    # both are intended for the deployed service.
    server_options = {"host": "0.0.0.0", "port": 7860, "reload": True}
    uvicorn.run("app:app", **server_options)