lang_id_flask / app.py
MSP RAJA
updated
f784d15
raw
history blame
2.48 kB
import logging
from flask import Flask, request, jsonify
import os
from wtforms import Form, StringField
from wtforms.validators import DataRequired
from config import model_ckpt, pipe, labels
app = Flask(__name__)

# Module-level logger so code in this module can always log via `logger`.
# Previously `logger` was only bound inside the `__main__` block, which does
# not run when the app is served by `flask run` or a WSGI server; any
# `logger.*` call then raised NameError. Handlers/levels are still configured
# by the `__main__` block when the module is run as a script.
logger = logging.getLogger(__name__)
class PredictForm(Form):
    """WTForms schema for a prediction request: a single, required ``text`` field."""

    text = StringField('text', validators=[DataRequired()])
def predict(text: str) -> dict:
    """
    Compute the most likely language for a piece of text.

    :param text: str : The text to be analyzed.
    :return: dict : ``{language: score}`` for the single highest-scoring label
        (the raw label is used when ``labels`` has no mapping for it), ``{}``
        when the pipeline yields no predictions, or ``{'error': message}``
        when inference fails. This function does not raise.
    """
    # Resolve the logger locally instead of relying on a global `logger`,
    # which may not exist when the app is not started via `__main__`.
    log = logging.getLogger(__name__)
    try:
        preds = pipe(text, return_all_scores=True, truncation=True, max_length=128)
        # preds[0] is treated as the list of {'label', 'score'} dicts for the
        # single input text (assumed pipeline output shape — confirm against
        # the configured pipeline).
        if not preds or not preds[0]:
            return {}
        # max() returns the first highest-scoring entry, the same element that
        # sorted(..., reverse=True)[0] picked, without sorting the whole list.
        best = max(preds[0], key=lambda p: p["score"])
        return {labels.get(best["label"], best["label"]): float(best["score"])}
    except Exception as e:
        # Best-effort boundary: log the traceback and report the failure as a
        # plain dict (not a Flask view tuple) so callers can jsonify it.
        log.exception("Error processing request: %s", e)
        return {'error': str(e)}
@app.route('/language', methods=['POST'])
def predict_language():
    """
    A Language Prediction API which accepts 'text' as input and returns the language of the text along with its score
    ---
    parameters:
      - in: body
        name: body
        required: true
        schema:
          type: object
          required:
            - text
          properties:
            text:
              type: string
              description: The text to be analyzed
    responses:
      200:
        description: A JSON object containing the language and its score
        schema:
          type: object
      400:
        description: Invalid request
      500:
        description: Internal server error
    """
    # silent=True makes a missing/unparseable JSON body return None instead of
    # raising, so it is reported as a 400 rather than an unhandled error.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'text' not in payload:
        return jsonify({'error': 'Invalid input provided'}), 400
    text = payload['text']
    # Reject non-string values (numbers, lists, objects) before they reach the model.
    if text is not None and not isinstance(text, str):
        return jsonify({'error': 'Invalid input provided'}), 400
    # None, "" and whitespace-only strings carry nothing to classify.
    if not text or not text.strip():
        return jsonify({'error': 'Empty text provided'}), 400
    result = predict(text)
    # An {'error': ...} result signals a failure inside inference -> 500.
    if isinstance(result, dict) and 'error' in result:
        return jsonify(result), 500
    if result:
        return jsonify(result)
    return jsonify({'error': 'No predictions found'}), 400
if __name__ == '__main__':
    # Script entry point: set up file logging, then start the Flask server.
    log_file = 'app.log'
    logging.basicConfig(
        level=logging.INFO,
        filename=log_file,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    logger = logging.getLogger(__name__)
    logger.info("Running the app...")
    # app.run() with no arguments uses Flask's built-in server defaults.
    app.run()