import os

# Point the Hugging Face cache at a writable directory *before* transformers is
# imported, so the cache path takes effect when the library computes its defaults
# (e.g. /tmp on hosted containers where the default cache location is read-only).
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Initialize the Flask app
app = Flask(__name__)

# Load the model and tokenizer once at startup so every request reuses them
MODEL_NAME = "s-nlp/roberta-base-formality-ranker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
@app.route("/", methods=["GET"])
def home():
return jsonify({"message": "Formality Classifier API is running! Use /predict to classify text."})
@app.route("/predict", methods=["POST"])
def predict_formality():
data = request.get_json()
if not data or "text" not in data:
return jsonify({"error": "Text input is required"}), 400
text = data["text"]
# Tokenize input
encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
# Predict formality score
with torch.no_grad():
logits = model(**encoding).logits
score = logits.softmax(dim=1)[:, 1].item()
# Convert score to formality classification
formal_percent = round(score * 100)
informal_percent = 100 - formal_percent
classification = f"Your speech is {formal_percent}% formal and {informal_percent}% informal."
return jsonify({
"text": text,
"formality_score": round(score, 3),
"formal_percent": formal_percent,
"informal_percent": informal_percent,
"classification": classification
})

if __name__ == "__main__":
    # Bind to all interfaces on port 7860, the port Hugging Face Spaces expects by default
    app.run(host="0.0.0.0", port=7860)
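
# Example request (a sketch, assuming the server is running locally on port 7860;
# on a hosted deployment, replace localhost with the app's public URL):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Dear Sir or Madam, I am writing to request further information."}'
#
# The JSON response mirrors the keys returned by predict_formality; values depend on the input:
#
#   {"text": "...", "formality_score": <float 0-1>, "formal_percent": <int 0-100>,
#    "informal_percent": <int 0-100>, "classification": "Your speech is N% formal and M% informal."}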