cheesecz commited on
Commit
4d95560
·
verified ·
1 Parent(s): 8da26a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+
5
+ app = Flask(__name__)
6
+
7
+ # Load the pretrained model and tokenizer
8
+ model_name = "s-nlp/roberta-base-formality-ranker"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
+
12
+ # Fuzzy classification function
13
+ def fuzzy_formality(score, threshold=0.75):
14
+ if score < threshold:
15
+ formal_weight = 0.5 * (score / threshold) ** 2
16
+ else:
17
+ formal_weight = 1 - 0.5 * ((1 - score) / (1 - threshold)) ** 2
18
+
19
+ informal_weight = 1 - formal_weight
20
+ formal_percent = round(formal_weight * 100)
21
+ informal_percent = round(informal_weight * 100)
22
+
23
+ return {
24
+ "formal_percent": formal_percent,
25
+ "informal_percent": informal_percent,
26
+ "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal."
27
+ }
28
+
29
+ @app.route("/predict", methods=["POST"])
30
+ def predict_formality():
31
+ # Get input text from request
32
+ text = request.json.get("text")
33
+ if not text:
34
+ return jsonify({"error": "Text input is required"}), 400
35
+
36
+ # Tokenize input
37
+ encoding = tokenizer(
38
+ text,
39
+ add_special_tokens=True,
40
+ truncation=True,
41
+ padding="max_length",
42
+ return_tensors="pt"
43
+ )
44
+
45
+ # Get predictions
46
+ with torch.no_grad():
47
+ output = model(**encoding)
48
+
49
+ # Extract formality score
50
+ softmax_scores = output.logits.softmax(dim=1)
51
+ formality_score = softmax_scores[:, 1].item() # Extract formal score
52
+
53
+ # Classify using fuzzy logic
54
+ result = fuzzy_formality(formality_score)
55
+
56
+ return jsonify({
57
+ "text": text,
58
+ "formality_score": round(formality_score, 3),
59
+ **result
60
+ })
61
+
62
+ if __name__ == "__main__":
63
+ app.run(host="0.0.0.0", port=7860)