Spaces:
Running
Running
File size: 1,814 Bytes
fa41e21 70b18ac fa41e21 53b8b52 fa41e21 70b18ac fa41e21 430236e fa41e21 ba5e3fb fa41e21 175968e fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 c0734d1 fa41e21 430236e 70b18ac c3e3e66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer
import os
from transformers import RobertaForSequenceClassification
import torch.serialization
# Initialize Flask app
app = Flask(__name__)

# Tokenizer for the model inputs; presumably the same tokenizer the model in
# MODEL_PATH was fine-tuned with — TODO confirm if the model is retrained.
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

# Location of the serialized model. Defaults to "model.pth" (resolved against
# the current working directory); set the MODEL_PATH environment variable to
# load a model from elsewhere.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.pth")

# NOTE: the safe-globals allowlist is only consulted when torch.load is called
# with weights_only=True. With weights_only=False (below) it has no effect; it
# is kept so that moving to the restricted loader is a one-argument change.
torch.serialization.add_safe_globals([RobertaForSequenceClassification])

# SECURITY: weights_only=False unpickles the whole model object, and unpickling
# can execute code embedded in the file. Only point MODEL_PATH at files from a
# trusted source.
model = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=False)

# Switch layers such as dropout to inference behaviour before serving.
model.eval()
@app.route("/")
def home():
    """Root endpoint: reply with the full URL the client requested.

    Useful as a lightweight liveness check for the service.
    """
    requested_url = request.url
    return requested_url
@app.route("/predict", methods=["GET", "POST"])
def predict():
    """Score a code snippet with the loaded model.

    Expects a JSON object body with a string field ``code``, for example
    ``{"code": "def f(): pass"}``. POST is the conventional method for a JSON
    body; GET is still accepted so existing callers keep working.

    Returns:
        200 with ``{"predicted_score": <float>}`` on success.
        400 with ``{"error": ...}`` when the body is not a JSON object, or
        ``code`` is missing or not a string.
        500 with ``{"error": ...}`` when tokenization or inference fails.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising, so bad client input is reported as 400 rather than a generic 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "code" not in data:
        return jsonify({"error": "Missing 'code' parameter"}), 400
    code_input = data["code"]
    if not isinstance(code_input, str):
        return jsonify({"error": "'code' must be a string"}), 400

    # Logged only after validation, so a bad request can no longer crash here.
    app.logger.debug("Received code: %s", code_input)

    try:
        # Tokenize with the CodeBERT tokenizer, fixed to a 512-token window.
        inputs = tokenizer(
            code_input,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # .item() requires a single-element tensor, i.e. one input and a model
        # with a single output score.
        prediction = outputs.logits.squeeze().item()
    except Exception as e:  # request boundary: record the failure, reply 500
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500

    app.logger.debug("Predicted score: %s", prediction)
    return jsonify({"predicted_score": prediction})
# Entry point when executed directly as a script.
if __name__ == "__main__":
    # Start Flask's built-in server, reachable on all network interfaces.
    listen_host, listen_port = "0.0.0.0", 7860
    app.run(host=listen_host, port=listen_port)