File size: 1,789 Bytes
fa41e21
 
 
 
70b18ac
 
fa41e21
53b8b52
fa41e21
 
 
70b18ac
 
 
fa41e21
430236e
fa41e21
 
 
 
 
ba5e3fb
fa41e21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430236e
70b18ac
c3e3e66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer
import os
from transformers import RobertaForSequenceClassification
import torch.serialization
# Initialize Flask app
app = Flask(__name__)

# Path to the serialized model. Override with the MODEL_PATH environment
# variable; the default "model.pth" is resolved against the current working
# directory, exactly as before.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.pth")

# Load the CodeBERT tokenizer used to encode incoming code snippets.
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

# NOTE: add_safe_globals() only affects torch.load(..., weights_only=True).
# The load below uses weights_only=False, so this allowlist currently has no
# effect; it is kept so a switch to weights-only loading needs one less step.
torch.serialization.add_safe_globals([RobertaForSequenceClassification])

# SECURITY: weights_only=False runs the full pickle unpickler, which can
# execute arbitrary code embedded in the file. Only load model files from a
# trusted source. A safer layout is to persist model.state_dict() (or use
# save_pretrained()) and reload into a freshly constructed model.
model = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False)

# Put the model in evaluation mode so dropout / normalization layers use
# inference behavior.
model.eval()


@app.route("/")
def home():
    """Echo back the full URL of the current request."""
    current_url = request.url
    return current_url


@app.route("/predict", methods=["POST"])
def predict():
    """Score a code snippet with the loaded model.

    Expects a JSON object body of the form ``{"code": "<source text>"}``.

    Returns:
        200 with ``{"predicted_score": <float>}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object, or
        when ``code`` is missing or not a string.
        500 with ``{"error": ...}`` if tokenization or inference fails.
    """
    # silent=True makes get_json() return None instead of raising on a
    # missing or malformed JSON body, so bad requests get a 400 here rather
    # than falling into the generic 500 handler below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "code" not in data:
        return jsonify({"error": "Missing 'code' parameter"}), 400

    code_input = data["code"]
    if not isinstance(code_input, str):
        return jsonify({"error": "'code' must be a string"}), 400

    try:
        # Debugging: echo the input only after it has been validated.
        print("Received code:", code_input)

        # Tokenize the input code using the CodeBERT tokenizer
        inputs = tokenizer(
            code_input,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=512
        )

        # Make prediction using the model
        with torch.no_grad():
            outputs = model(**inputs)
            # .item() requires a single-element tensor, so this assumes the
            # model emits one logit per input (regression head) — confirm
            # against how model.pth was trained.
            prediction = outputs.logits.squeeze().item()

        print(f"Predicted score: {prediction}")  # Debugging: Print prediction

        return jsonify({"predicted_score": prediction})

    except Exception as e:
        # Keep the full traceback server-side; the client still receives the
        # exception message, as before. NOTE(review): str(e) may leak
        # internals to callers — consider a generic message in production.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500


# Run the Flask app with its built-in server. The port can be overridden via
# the PORT environment variable and defaults to 7860, as before.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))