from flask import Flask, request, jsonify
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import os

app = Flask(__name__)

# Global variables to store model and tokenizer
global_tokenizer = None
global_model = None

def load_model():
    """Load the model and tokenizer into the module-level globals."""
    global global_tokenizer, global_model
    try:
        print("Loading model and tokenizer...")
        # Replace this with the directory or Hugging Face Hub ID of your
        # fine-tuned email classifier. The SST-2 sentiment model below is a
        # known-good stand-in used to test the serving pipeline.
        MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
        # For a local model, point from_pretrained at the directory instead:
        # model_dir = "/path/to/your/local/model"
        global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
        global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
        global_model.eval()  # inference mode: disables dropout
        print("Model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False

# Load model at startup
model_loaded = load_model()

@app.route('/', methods=['GET'])
def home():
    """Home endpoint to check whether the API is running."""
    response = {
        'status': 'API is running',
        'model_status': 'loaded' if model_loaded else 'not loaded',
        'usage': {
            'endpoint': '/classify',
            'method': 'POST',
            'body': {'subject': 'Your email subject here'}
        }
    }
    return jsonify(response)

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    if not model_loaded:
        return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
    return jsonify({'status': 'healthy'})

@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject line."""
    if not model_loaded:
        return jsonify({'error': 'Model not loaded'}), 503
    try:
        # Parse and validate the request body
        data = request.get_json()
        if not data or 'subject' not in data:
            return jsonify({
                'error': 'No subject provided. Please send JSON with a "subject" field.'
            }), 400
        subject = data['subject']

        # Tokenize (DistilBERT accepts at most 512 tokens)
        inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)

        # Predict without tracking gradients
        with torch.no_grad():
            outputs = global_model(**inputs)
            logits = outputs.logits

        # Convert logits to probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class_id = logits.argmax().item()
        confidence = probabilities[0][predicted_class_id].item()

        # Map class ids to custom labels. Note: with the SST-2 test model,
        # id 0 is NEGATIVE and id 1 is POSITIVE, so these labels are
        # placeholders until a model fine-tuned on email data is plugged in.
        CUSTOM_LABELS = {
            0: "Business/Professional",
            1: "Personal/Casual"
        }

        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            'all_categories': {
                label: round(prob.item(), 3)
                for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
            }
        }
        return jsonify(result)
    except Exception as e:
        print(f"Error in classification: {str(e)}")
        return jsonify({'error': str(e)}), 500
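
# Example request against a locally running instance (a sketch; the subject
# text below is made up, and the port assumes the default set in __main__):
#
#   curl -X POST http://localhost:7860/classify \
#        -H "Content-Type: application/json" \
#        -d '{"subject": "Quarterly revenue report attached"}'
#
# which returns JSON of the form:
#   {"category": "...", "confidence": 0.982, "all_categories": {...}}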

if __name__ == '__main__':
    # Use port 7860 for Hugging Face Spaces or any other port for local testing
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
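
# Deployment note (an assumption, not part of the original app): app.run()
# starts Flask's development server, which is not intended for production.
# A WSGI server such as gunicorn is the usual choice there, e.g.:
#   gunicorn --bind 0.0.0.0:7860 app:app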