File size: 1,869 Bytes
13ec7ec
 
 
4d95560
f72b785
 
 
 
4d95560
 
f78e601
4d95560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13ec7ec
4d95560
 
 
 
 
 
 
 
13ec7ec
4d95560
 
 
 
 
 
 
 
 
 
f72b785
4d95560
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os

os.environ["TRANSFORMERS_CACHE"] = "/tmp"

from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

app = Flask(__name__)

# Hub checkpoint of the formality classifier. The tokenizer and the model are
# loaded once, when this module is imported, and reused by every request.
model_name = "s-nlp/roberta-base-formality-ranker"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSequenceClassification.from_pretrained(model_name),
)

# Fuzzy classification function
def fuzzy_formality(score, threshold=0.75):
    """Split a formality probability into formal/informal percentages.

    The score is mapped through a piecewise-quadratic S-curve centred on
    ``threshold``.  Below the threshold the formal weight rises
    quadratically from 0 to 0.5.  At or above it, the weight rises from 0.5
    towards 1.  The curve is continuous at the threshold, where the split is
    exactly 50/50.

    Args:
        score: Probability that the text is formal.  Values outside [0, 1]
            are clamped into that range.
        threshold: Score that maps to a 50/50 split.  Must lie strictly
            between 0 and 1.

    Returns:
        dict with ``formal_percent`` and ``informal_percent`` (integers that
        always sum to 100) and a human-readable ``classification`` sentence.

    Raises:
        ValueError: if ``threshold`` is not strictly between 0 and 1.
    """
    if not 0 < threshold < 1:
        # threshold == 1 would divide by zero at score == 1, and values
        # outside (0, 1) put the 50/50 point outside the score range.
        raise ValueError(
            f"threshold must be strictly between 0 and 1, got {threshold!r}"
        )

    # Clamp so an out-of-range score cannot push the weight outside [0, 1].
    score = min(max(score, 0.0), 1.0)

    if score < threshold:
        formal_weight = 0.5 * (score / threshold) ** 2
    else:
        formal_weight = 1 - 0.5 * ((1 - score) / (1 - threshold)) ** 2

    formal_percent = round(formal_weight * 100)
    # Derive the complement from the rounded value so the two percentages
    # always add up to 100.  Rounding each one independently can give 99 or
    # 101 because of float error.
    informal_percent = 100 - formal_percent

    return {
        "formal_percent": formal_percent,
        "informal_percent": informal_percent,
        "classification": f"Your speech is {formal_percent}% formal and {informal_percent}% informal."
    }

def _score_formality(text):
    """Return the model's softmax probability of class index 1 for *text*."""
    # No padding: a single sequence needs none.  Padding to max_length would
    # run the model over hundreds of pad tokens for short inputs.
    encoding = tokenizer(
        text, add_special_tokens=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**encoding).logits
    # NOTE(review): assumes label index 1 is "formal" for this checkpoint;
    # confirm against model.config.id2label.
    return logits.softmax(dim=1)[0, 1].item()


@app.route("/predict", methods=["POST"])
def predict_formality():
    """Score the formality of the ``text`` field of a JSON POST body.

    Expects a JSON object such as ``{"text": "..."}``.  Responds with the
    echoed text, the formality probability rounded to 3 decimals, and the
    percentage split produced by ``fuzzy_formality``.

    Returns:
        A JSON response.  HTTP 400 with an ``error`` key is returned when the
        body is not a JSON object, or when ``text`` is missing, empty,
        whitespace-only, or not a string.
    """
    # silent=True yields None for a missing or unparsable JSON body instead of
    # raising.  Together with the dict check, non-object bodies (arrays, null,
    # malformed JSON, wrong content type) get a 400 rather than a 415/500.
    payload = request.get_json(silent=True)
    text = payload.get("text") if isinstance(payload, dict) else None
    if text is not None and not isinstance(text, str):
        return jsonify({"error": "Text input must be a string"}), 400
    if not text or not text.strip():
        return jsonify({"error": "Text input is required"}), 400

    formality_score = _score_formality(text)

    # Classify using fuzzy logic
    result = fuzzy_formality(formality_score)

    return jsonify({
        "text": text,
        "formality_score": round(formality_score, 3),
        **result
    })

# Entry point when the file is executed directly: start Flask's built-in
# server listening on all interfaces (0.0.0.0), port 7860.
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")