File size: 1,814 Bytes
fa41e21
 
 
 
70b18ac
 
fa41e21
53b8b52
fa41e21
 
 
70b18ac
 
 
fa41e21
430236e
fa41e21
 
 
 
 
ba5e3fb
fa41e21
 
175968e
 
fa41e21
c0734d1
 
 
fa41e21
c0734d1
 
 
fa41e21
c0734d1
fa41e21
c0734d1
 
 
 
 
 
 
 
fa41e21
c0734d1
 
 
 
fa41e21
c0734d1
fa41e21
c0734d1
fa41e21
c0734d1
 
fa41e21
 
430236e
70b18ac
c3e3e66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer
import os
from transformers import RobertaForSequenceClassification
import torch.serialization
# Initialize Flask app
# (WSGI application object; the route handlers below register on it.)
app = Flask(__name__)

# Load the trained model and tokenizer.
# Everything in this section runs once, at import time, before any request
# is served.
# The tokenizer is resolved by Hugging Face hub id, so the first run may need
# network access or a pre-populated local cache.
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
# Allowlist the model class for torch's restricted (weights_only=True)
# unpickler. NOTE(review): the torch.load call below passes
# weights_only=False, so this allowlist is not consulted there — confirm
# whether it is still needed.
torch.serialization.add_safe_globals([RobertaForSequenceClassification])

# SECURITY: weights_only=False performs a full unpickle, which can execute
# arbitrary code embedded in the file — only ever load a trusted model.pth.
# map_location forces every tensor onto the CPU so no GPU is required.
# Presumably model.pth holds a complete pickled RobertaForSequenceClassification
# (not just a state_dict), since the result is used directly as a module —
# TODO confirm.
model = torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False)  # Load the trained model

# Ensure the model is in evaluation mode (disables dropout etc. for inference)
model.eval()


@app.route("/")
def home():
    """Root endpoint: respond with the full URL that was requested."""
    requested_url = request.url  # scheme, host, path and query string
    return requested_url


@app.route("/predict", methods=["GET", "POST"])
def predict():
    """Score a code snippet with the loaded CodeBERT model.

    Expects a JSON object body of the form ``{"code": "<source text>"}``.
    Both GET (the original registration) and POST are accepted, so existing
    clients keep working. New clients can use POST, the conventional verb
    for a request that carries a body.

    Returns:
        200 ``{"predicted_score": <float>}`` on success.
        400 ``{"error": ...}`` when the body is not a JSON object, lacks a
            ``code`` field, or ``code`` is not a string.
        500 ``{"error": ...}`` when tokenization or inference fails.
    """
    # silent=True yields None instead of raising for a missing or invalid
    # JSON body. Client mistakes then map to 400 rather than the catch-all 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    if "code" not in data:
        return jsonify({"error": "Missing 'code' parameter"}), 400

    code_input = data["code"]
    if not isinstance(code_input, str):
        return jsonify({"error": "'code' must be a string"}), 400

    # Debugging: echo the input, only after it has been validated.
    print("Received code:", code_input)

    try:
        # Tokenize with the CodeBERT tokenizer. Longer inputs are truncated
        # to 512 tokens and shorter ones are padded to exactly 512.
        inputs = tokenizer(
            code_input,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=512
        )

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)
            # .item() requires a single-element tensor. This presumes the
            # model has a single regression output (num_labels == 1) — TODO
            # confirm.
            prediction = outputs.logits.squeeze().item()
    except Exception as e:  # request boundary: log the traceback, report 500
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500

    print(f"Predicted score: {prediction}")  # Debugging: Print prediction

    return jsonify({"predicted_score": prediction})


# Start the development server only when run as a script, not on import.
if __name__ == "__main__":
    # Bind every network interface on port 7860.
    app.run(port=7860, host="0.0.0.0")