cheesecz commited on
Commit
13ec7ec
·
verified ·
1 Parent(s): f78e601

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import os
2
-
3
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
4
-
5
  from flask import Flask, request, jsonify
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  import torch
 
 
 
 
8
 
9
  app = Flask(__name__)
10
 
@@ -32,18 +32,13 @@ def fuzzy_formality(score, threshold=0.75):
32
 
33
  @app.route("/predict", methods=["POST"])
34
  def predict_formality():
35
- # Get input text from request
36
  text = request.json.get("text")
37
  if not text:
38
  return jsonify({"error": "Text input is required"}), 400
39
 
40
  # Tokenize input
41
  encoding = tokenizer(
42
- text,
43
- add_special_tokens=True,
44
- truncation=True,
45
- padding="max_length",
46
- return_tensors="pt"
47
  )
48
 
49
  # Get predictions
@@ -52,7 +47,7 @@ def predict_formality():
52
 
53
  # Extract formality score
54
  softmax_scores = output.logits.softmax(dim=1)
55
- formality_score = softmax_scores[:, 1].item() # Extract formal score
56
 
57
  # Classify using fuzzy logic
58
  result = fuzzy_formality(formality_score)
@@ -63,5 +58,6 @@ def predict_formality():
63
  **result
64
  })
65
 
 
66
  if __name__ == "__main__":
67
  app.run(host="0.0.0.0", port=7860)
 
 
 
 
 
1
  from flask import Flask, request, jsonify
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import os
5
+
6
+ # Set writable cache
7
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
8
 
9
  app = Flask(__name__)
10
 
 
32
 
33
  @app.route("/predict", methods=["POST"])
34
  def predict_formality():
 
35
  text = request.json.get("text")
36
  if not text:
37
  return jsonify({"error": "Text input is required"}), 400
38
 
39
  # Tokenize input
40
  encoding = tokenizer(
41
+ text, add_special_tokens=True, truncation=True, padding="max_length", return_tensors="pt"
 
 
 
 
42
  )
43
 
44
  # Get predictions
 
47
 
48
  # Extract formality score
49
  softmax_scores = output.logits.softmax(dim=1)
50
+ formality_score = softmax_scores[:, 1].item()
51
 
52
  # Classify using fuzzy logic
53
  result = fuzzy_formality(formality_score)
 
58
  **result
59
  })
60
 
61
+ # Run on correct port
62
  if __name__ == "__main__":
63
  app.run(host="0.0.0.0", port=7860)