import logging
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoConfig
from model_definition import MultitaskCodeSimilarityModel
from typing import List
import uvicorn
from datetime import datetime

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Deployment metadata
DEPLOYMENT_DATE = "2025-06-10 15:11:04"
DEPLOYED_BY = "Fastest"

# Get device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Your Hugging Face model repository
REPO_ID = "FastestAI/Redundant_Model"

# Initialize FastAPI app
app = FastAPI(
    title="Test Similarity Analyzer API",
    description="API for analyzing similarity between test cases. Deployed by " + DEPLOYED_BY,
    version="1.0.0",
    docs_url="/",
)

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API labels are 1-based: 1 = Duplicate, 2 = Redundant, 3 = Distinct
label_to_class = {1: "Duplicate", 2: "Redundant", 3: "Distinct"}

# The model outputs 0-based class indices (0, 1, 2); map them to the 1-based API labels
model_to_api_label = {0: 1, 1: 2, 2: 3}

# Define input models for API
class SourceCode(BaseModel):
    class_name: str
    code: str

class TestCase(BaseModel):
    id: str
    test_fixture: str
    name: str
    code: str
    target_class: str
    target_method: List[str]

class SimilarityInput(BaseModel):
    pair_id: str
    source_code: SourceCode
    test_case_1: TestCase
    test_case_2: TestCase

# Global variables for model and tokenizer
model = None
tokenizer = None

# Load model and tokenizer on startup
@app.on_event("startup")
async def startup_event():
    global model, tokenizer
    try:
        logger.info(f"Loading model and tokenizer from {REPO_ID}...")
        
        # Load tokenizer directly from Hugging Face
        tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
        
        # Load config from Hugging Face
        config = AutoConfig.from_pretrained(REPO_ID)
        
        # Create model instance using imported MultitaskCodeSimilarityModel class
        model = MultitaskCodeSimilarityModel(config, tokenizer)
        
        # Load weights directly from Hugging Face
        state_dict = torch.hub.load_state_dict_from_url(
            f"https://huggingface.co/{REPO_ID}/resolve/main/pytorch_model.bin",
            map_location=device,
            check_hash=False
        )
        model.load_state_dict(state_dict)
        
        # Move model to device and set to evaluation mode
        model.to(device)
        model.eval()
        
        logger.info("Model and tokenizer loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        model = None
        tokenizer = None

@app.get("/health", tags=["Health"])
async def health_check():
    """Health check endpoint that also returns deployment information"""
    if model is None or tokenizer is None:
        return {
            "status": "error", 
            "message": "Model or tokenizer not loaded",
            "deployment_date": DEPLOYMENT_DATE,
            "deployed_by": DEPLOYED_BY
        }
    
    return {
        "status": "ok", 
        "model": REPO_ID, 
        "device": str(device),
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY,
        "current_time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
    }

@app.post("/predict")
async def predict(data: SimilarityInput):
    """
    Predict similarity class between two test cases for a given source class.
    
    Input schema follows the specified format with source_code, test_case_1, and test_case_2.
    Uses heuristics to detect class and method differences before using the model.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=500, detail="Model or tokenizer not loaded correctly")
    
    try:
        # Apply heuristics for method and class differences
        class_1 = data.test_case_1.target_class
        class_2 = data.test_case_2.target_class
        method_1 = data.test_case_1.target_method
        method_2 = data.test_case_2.target_method
        
        # Check if we can determine similarity without using the model
        if class_1 and class_2 and class_1 != class_2:
            logger.info("Heuristic detection: different target classes - Distinct")
            api_prediction = 3  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        elif method_1 and method_2 and not set(method_1).intersection(set(method_2)):
            logger.info("Heuristic detection: different target methods - Distinct")
            api_prediction = 3  # Distinct
            probs = [0.0, 0.0, 1.0]  # 100% confidence in Distinct
        else:
            # No clear heuristic match, use the model
            # Format input to match training format
            combined_input = (
                f"SOURCE CODE: {data.source_code.code}\n"
                f"TEST 1: {data.test_case_1.code}\n"
                f"TEST 2: {data.test_case_2.code}"
            )

            # Tokenize input
            inputs = tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

            # Run the custom multitask model (inference only, no gradients)
            with torch.no_grad():
                logits, _ = model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"]
                )

            # Process results
            probs = F.softmax(logits, dim=-1)[0].cpu().tolist()
            model_prediction = torch.argmax(logits, dim=-1).item()
            
            # Convert model prediction (0,1,2) to API prediction (1,2,3)
            api_prediction = model_to_api_label[model_prediction]
            logger.info(f"Model prediction: {label_to_class[api_prediction]}")
        
        # Map prediction to class name
        classification = label_to_class.get(api_prediction, "Unknown")
        
        return {
            "pair_id": data.pair_id,
            "test_case_1_name": data.test_case_1.name,
            "test_case_2_name": data.test_case_2.name,
            "similarity": {
                "score": api_prediction,
                "classification": classification,
            },
            "probabilities": probs
        }
    
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        logger.error(f"Prediction error: {str(e)}")
        logger.error(error_trace)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
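
# Illustrative shape of a POST /predict response (the values below are made up;
# the actual score, classification, and probabilities depend on the loaded model):
#
#   {
#       "pair_id": "example-1",
#       "test_case_1_name": "testAddsTwoPositiveNumbers",
#       "test_case_2_name": "testAddsTwoPositiveIntegers",
#       "similarity": {"score": 1, "classification": "Duplicate"},
#       "probabilities": [0.91, 0.07, 0.02]
#   }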

# Example endpoint
@app.get("/example", response_model=SimilarityInput, tags=["Examples"])
async def get_example():
    """Get an example input to test the API"""
    return SimilarityInput(
        pair_id="example-1",
        source_code=SourceCode(
            class_name="Calculator",
            code="class Calculator {\n    public int add(int a, int b) {\n        return a + b;\n    }\n}"
        ),
        test_case_1=TestCase(
            id="test-1",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveNumbers",
            code="TEST(CalculatorTest, AddsTwoPositiveNumbers) {\n    Calculator calc;\n    EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        ),
        test_case_2=TestCase(
            id="test-2",
            test_fixture="CalculatorTest",
            name="testAddsTwoPositiveIntegers",
            code="TEST(CalculatorTest, AddsTwoPositiveIntegers) {\n    Calculator calc;\n    EXPECT_EQ(5, calc.add(2, 3));\n}",
            target_class="Calculator",
            target_method=["add"]
        )
    )

@app.get("/", tags=["Root"])
async def root():
    """
    Redirect to the API documentation.
    This is a convenience endpoint that redirects to the auto-generated docs.
    """
    return {
        "message": "Test Similarity Analyzer API",
        "documentation": "/docs",
        "deployment_date": DEPLOYMENT_DATE,
        "deployed_by": DEPLOYED_BY
    }

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
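
# Minimal client sketch (illustrative; assumes the server is running locally on
# port 7860 and that the `requests` package is installed):
#
#   import requests
#
#   example = requests.get("http://localhost:7860/example").json()
#   response = requests.post("http://localhost:7860/predict", json=example)
#   print(response.json()["similarity"])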