ElPremOoO commited on
Commit
3aa7633
·
verified ·
1 Parent(s): c0734d1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -38
main.py CHANGED
@@ -1,62 +1,67 @@
1
  from flask import Flask, request, jsonify
2
  import torch
3
- from transformers import RobertaTokenizer
4
  import os
5
- from transformers import RobertaForSequenceClassification
6
- import torch.serialization
7
- # Initialize Flask app
8
- app = Flask(__name__)
9
-
10
- # Load the trained model and tokenizer
11
- tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
12
- torch.serialization.add_safe_globals([RobertaForSequenceClassification])
13
 
14
- model = torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False) # Load the trained model
15
-
16
- # Ensure the model is in evaluation mode
17
- model.eval()
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @app.route("/")
21
  def home():
22
- return request.url
23
 
24
-
25
- # @app.route("/predict", methods=["POST"])
26
- @app.route("/predict")
27
  def predict():
28
  try:
29
- # Debugging: print input code to check if the request is received correctly
30
- print("Received code:", request.get_json()["code"])
31
-
32
  data = request.get_json()
33
  if "code" not in data:
34
  return jsonify({"error": "Missing 'code' parameter"}), 400
35
-
36
- code_input = data["code"]
37
-
38
- # Tokenize the input code using the CodeBERT tokenizer
39
  inputs = tokenizer(
40
- code_input,
41
- return_tensors='pt',
42
  truncation=True,
43
  padding='max_length',
44
- max_length=512
 
45
  )
46
-
47
- # Make prediction using the model
48
  with torch.no_grad():
49
  outputs = model(**inputs)
50
- prediction = outputs.logits.squeeze().item() # Extract the predicted score (single float)
51
-
52
- print(f"Predicted score: {prediction}") # Debugging: Print prediction
53
-
54
- return jsonify({"predicted_score": prediction})
55
-
 
 
 
56
  except Exception as e:
57
  return jsonify({"error": str(e)}), 500
58
 
59
-
60
- # Run the Flask app
61
  if __name__ == "__main__":
62
- app.run(host="0.0.0.0", port=7860)
 
1
  from flask import Flask, request, jsonify
2
  import torch
3
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaConfig
4
  import os
 
 
 
 
 
 
 
 
5
 
6
+ app = Flask(__name__)
 
 
 
7
 
8
+ # Load model and tokenizer
9
+ def load_model():
10
+ # Load saved config and weights
11
+ checkpoint = torch.load("codebert_readability_scorer.pth", map_location=torch.device('cpu'))
12
+ config = RobertaConfig.from_dict(checkpoint['config'])
13
+
14
+ # Initialize model with loaded config
15
+ model = RobertaForSequenceClassification(config)
16
+ model.load_state_dict(checkpoint['model_state_dict'])
17
+ model.eval()
18
+ return model
19
+
20
+ # Load components
21
+ try:
22
+ tokenizer = RobertaTokenizer.from_pretrained("./tokenizer")
23
+ model = load_model()
24
+ print("Model and tokenizer loaded successfully!")
25
+ except Exception as e:
26
+ print(f"Error loading model: {str(e)}")
27
 
28
  @app.route("/")
29
  def home():
30
+ return "Code Readability Scoring API - Send POST request to /predict with code snippet"
31
 
32
+ @app.route("/predict", methods=["POST"])
 
 
33
  def predict():
34
  try:
35
+ # Get code from request
 
 
36
  data = request.get_json()
37
  if "code" not in data:
38
  return jsonify({"error": "Missing 'code' parameter"}), 400
39
+
40
+ code = data["code"]
41
+
42
+ # Tokenize input
43
  inputs = tokenizer(
44
+ code,
 
45
  truncation=True,
46
  padding='max_length',
47
+ max_length=512,
48
+ return_tensors='pt'
49
  )
50
+
51
+ # Make prediction
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
+
55
+ # Apply sigmoid and format score
56
+ score = torch.sigmoid(outputs.logits).item()
57
+
58
+ return jsonify({
59
+ "readability_score": round(score, 4),
60
+ "processed_code": code[:500] + "..." if len(code) > 500 else code
61
+ })
62
+
63
  except Exception as e:
64
  return jsonify({"error": str(e)}), 500
65
 
 
 
66
  if __name__ == "__main__":
67
+ app.run(host="0.0.0.0", port=7860)