Email_Model / app.py
aideveloper24's picture
Create app.py
efa6633 verified
raw
history blame
3.41 kB
from flask import Flask, request, jsonify, make_response
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import os
# Flask application object that the route decorators in this file register against.
app = Flask(__name__)

# Module-level slots for the loaded model and tokenizer.
# Both are None until load_model() assigns them.
global_model = None
global_tokenizer = None
def load_model(model_name: str = "distilbert-base-uncased-finetuned-sst-2-english") -> bool:
    """Load the tokenizer and model and store them in the module globals.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the checkpoint this app
            has always used, so existing ``load_model()`` calls are unchanged.

    Returns:
        True when both objects were loaded, False when loading raised.
        On failure ``global_tokenizer`` and ``global_model`` keep whatever
        values they had before the call.
    """
    global global_tokenizer, global_model
    try:
        print("Loading model and tokenizer...")
        tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        model = DistilBertForSequenceClassification.from_pretrained(model_name)
        # Put the module in evaluation mode before it is used for prediction.
        model.eval()
    except Exception as e:
        # Top-level boundary: report the failure and let the caller decide,
        # rather than crashing the process at import time.
        print(f"Error loading model: {str(e)}")
        return False

    # Publish both objects together, only after everything succeeded, so the
    # globals never hold a tokenizer without its matching model.
    global_tokenizer, global_model = tokenizer, model
    print("Model loaded successfully!")
    return True
# Load model at startup (this call runs when the module is imported).
# The boolean result is not checked: if loading fails the app still starts,
# and the endpoints that test the globals for None respond with 503.
load_model()
@app.route('/', methods=['GET'])
def home():
    """Report that the API is up, whether the model is loaded, and how to call /classify."""
    model_status = 'not loaded' if global_model is None else 'loaded'
    usage = {
        'endpoint': '/classify',
        'method': 'POST',
        'body': {'subject': 'Your email subject here'},
    }
    return jsonify({
        'status': 'API is running',
        'model_status': model_status,
        'usage': usage,
    })
@app.route('/health', methods=['GET'])
def health_check():
    """Answer 'healthy' when both model and tokenizer are loaded, otherwise 503 'unhealthy'."""
    ready = global_model is not None and global_tokenizer is not None
    if ready:
        return jsonify({'status': 'healthy'})
    return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
# Display names for the classifier's two output indices.
# NOTE(review): load_model() loads an SST-2 checkpoint by default, which by its
# name is a sentiment model; confirm that presenting its two classes as
# business/personal categories is intended.
_CUSTOM_LABELS = {
    0: "Business/Professional",
    1: "Personal/Casual",
}


@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject sent as JSON and return the predicted category.

    Expects a POST body of the form ``{"subject": "<text>"}``.

    Returns:
        A JSON response with ``category``, ``confidence`` and
        ``all_categories`` on success. Error responses carry an ``error``
        field: 503 when the model or tokenizer is not loaded, 400 when the
        body is not a JSON object with a string ``subject``, and 500 when
        tokenization or inference raises.
    """
    if global_model is None or global_tokenizer is None:
        return jsonify({'error': 'Model not loaded'}), 503

    # silent=True yields None for a missing, malformed or non-JSON body instead
    # of raising, so bad client input is answered with 400 rather than falling
    # into the broad handler below and being reported as a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'subject' not in data:
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400

    subject = data['subject']
    if not isinstance(subject, str):
        # Anything else (numbers, null, lists) would either make the tokenizer
        # raise or be treated as a batch, which the code below does not expect.
        return jsonify({'error': 'The "subject" field must be a string.'}), 400

    try:
        inputs = global_tokenizer(
            subject,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            logits = global_model(**inputs).logits
            # One input string -> row 0 holds the per-class scores.
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
            # argmax along the class dimension, not over the flattened tensor.
            predicted_class_id = logits.argmax(dim=1)[0].item()

        result = {
            'category': _CUSTOM_LABELS[predicted_class_id],
            'confidence': round(probabilities[predicted_class_id], 3),
            # Index by class id so the pairing does not depend on dict ordering.
            'all_categories': {
                label: round(probabilities[class_id], 3)
                for class_id, label in _CUSTOM_LABELS.items()
            },
        }
        return jsonify(result)
    except Exception as e:
        # Boundary handler: log and surface the failure as a 500 response.
        print(f"Error in classification: {str(e)}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Hugging Face Spaces uses port 7860; a PORT environment variable overrides it.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))