from flask import Flask, request, jsonify
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import os

app = Flask(__name__)

# Initialize model and tokenizer globally so they are loaded once at startup
print("Loading model and tokenizer...")
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
print("Model and tokenizer loaded successfully!")

# Custom labels mapped onto the SST-2 model's two output classes
# (class 0 = NEGATIVE, class 1 = POSITIVE in the original fine-tune)
CUSTOM_LABELS = {
    0: "Business/Professional",
    1: "Personal/Casual"
}


def classify_text(text):
    try:
        # Tokenize the input, truncating to the model's 512-token limit
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits
        # Convert logits to class probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class_id = logits.argmax().item()
        confidence = probabilities[0][predicted_class_id].item()

        return {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            'all_categories': {
                label: round(prob.item(), 3)
                for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
            }
        }
    except Exception as e:
        print(f"Error in classify_text: {str(e)}")
        raise


@app.route('/classify', methods=['POST'])
def classify_email():
    try:
        data = request.get_json()
        if not data or 'subject' not in data:
            return jsonify({
                'error': 'No subject provided. Please send JSON with a "subject" field.'
            }), 400

        subject = data['subject']
        result = classify_text(subject)
        return jsonify(result)
    except Exception as e:
        print(f"Error in classify_email: {str(e)}")
        return jsonify({'error': str(e)}), 500


@app.route('/', methods=['GET'])
def home():
    return jsonify({
        'status': 'API is running',
        'model_name': MODEL_NAME,
        'usage': {
            'endpoint': '/classify',
            'method': 'POST',
            'body': {'subject': 'Your email subject here'}
        }
    })


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting server on port {port}...")
    # debug=True is convenient during development; disable it for production
    app.run(host='0.0.0.0', port=port, debug=True)
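
# --- Usage sketch (commented out; not executed with the server) ---
# A minimal client example using the `requests` library, assuming the server
# is running locally on port 7860. The subject line and the response values
# shown are illustrative only, not actual model output.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/classify",
#       json={"subject": "Quarterly budget review meeting"},
#       timeout=10,
#   )
#   print(resp.json())
#   # e.g. {"category": "Business/Professional",
#   #       "confidence": 0.97,
#   #       "all_categories": {"Business/Professional": 0.97,
#   #                          "Personal/Casual": 0.03}}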