# Hugging Face Space (status at capture time: Sleeping)
from flask import Flask, request, jsonify, make_response | |
import torch | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
import os | |
app = Flask(__name__)

# Process-wide inference state. Both start out as None and are filled in by
# load_model(); the request handlers treat None as "model not loaded".
global_tokenizer, global_model = None, None
def load_model():
    """Load the pretrained DistilBERT tokenizer and classifier into module globals.

    Returns:
        True once both objects are loaded and the model has been put in
        evaluation mode. False if any step raised: the exception is caught
        and its message printed. Globals assigned before the failing step
        keep their new values.
    """
    global global_tokenizer, global_model
    checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
    try:
        print("Loading model and tokenizer...")
        global_tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
        global_model = DistilBertForSequenceClassification.from_pretrained(checkpoint)
        # Inference-only usage: eval() switches off training-time behaviour.
        global_model.eval()
        print("Model loaded successfully!")
        return True
    except Exception as exc:
        # Best effort: report and signal failure rather than crash the app.
        print(f"Error loading model: {exc!s}")
        return False
# Load the model once at module import time, so it is ready before the first
# request. Because this sits outside the __main__ guard, it also runs when a
# server process imports this module. load_model() prints the error and
# returns False instead of raising, and that return value is ignored here.
# The app therefore still starts; the handlers then report the model as not
# loaded (HTTP 503 from the health and classify endpoints).
load_model()
@app.route('/')
def home():
    """Report that the API is up, whether the model is loaded, and how to call it.

    Returns:
        A JSON response with ``status``, ``model_status`` ("loaded" or
        "not loaded") and a ``usage`` hint describing the classify endpoint.
    """
    # Match the readiness test used by the other handlers: both pieces are
    # required before a classification can run.
    model_ready = global_model is not None and global_tokenizer is not None
    response = {
        'status': 'API is running',
        'model_status': 'loaded' if model_ready else 'not loaded',
        'usage': {
            'endpoint': '/classify',
            'method': 'POST',
            'body': {'subject': 'Your email subject here'},
        },
    }
    return jsonify(response)
# NOTE(review): '/health' is a conventional path, not one named elsewhere in
# this file. Confirm it matches whatever probe is expected to call this.
@app.route('/health')
def health_check():
    """Report whether the model and tokenizer are loaded.

    Returns:
        ``{"status": "healthy"}`` with HTTP 200 when both are loaded.
        Otherwise ``{"status": "unhealthy", "error": "Model not loaded"}``
        with HTTP 503.
    """
    if global_model is None or global_tokenizer is None:
        return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
    return jsonify({'status': 'healthy'})
# Display names for the model's two class ids.
# NOTE(review): load_model() loads an SST-2 fine-tune (a sentiment dataset),
# whose ids presumably mean NEGATIVE/POSITIVE. These names relabel that
# sentiment output rather than come from a trained email-category classifier.
# Confirm the mapping is intended.
CUSTOM_LABELS = {
    0: "Business/Professional",
    1: "Personal/Casual",
}


@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify a single email subject line.

    Expects a JSON object body of the form ``{"subject": "<text>"}``.

    Returns:
        200 with ``category``, ``confidence`` and ``all_categories``
        (probabilities rounded to 3 decimals) on success.
        400 when the body is not a JSON object with a string ``subject``.
        503 when the model or tokenizer is not loaded.
        500 when tokenization or inference raises.
    """
    if global_model is None or global_tokenizer is None:
        return jsonify({'error': 'Model not loaded'}), 503

    # silent=True: malformed JSON or a non-JSON Content-Type yields None
    # (-> 400 below) instead of raising. A raised error would be reported as
    # a 500 server error by the broad handler further down.
    data = request.get_json(silent=True)
    # Require a JSON object. With a JSON string body, `'subject' in data` is a
    # substring test, and `data['subject']` would then raise TypeError.
    if not isinstance(data, dict) or 'subject' not in data:
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400

    subject = data['subject']
    # The tokenizer treats a list as a batch, which breaks the single-row
    # indexing below; other non-strings make it raise. Both are client errors.
    if not isinstance(subject, str):
        return jsonify({'error': 'The "subject" field must be a string.'}), 400

    try:
        inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = global_model(**inputs).logits
        # One input was tokenized, so take row 0 of the batch dimension.
        row_logits = logits[0]
        probabilities = torch.nn.functional.softmax(row_logits, dim=-1)
        predicted_class_id = row_logits.argmax().item()
        confidence = probabilities[predicted_class_id].item()
        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            'all_categories': {
                CUSTOM_LABELS[class_id]: round(prob, 3)
                for class_id, prob in enumerate(probabilities.tolist())
            },
        }
    except Exception as e:
        print(f"Error in classification: {str(e)}")
        return jsonify({'error': str(e)}), 500
    return jsonify(result)
if __name__ == '__main__':
    # Hugging Face Spaces serves on 7860; a PORT environment variable overrides it.
    port = int(os.getenv('PORT', '7860'))
    app.run(host='0.0.0.0', port=port)