Spaces:
Sleeping
Sleeping
File size: 2,457 Bytes
b3df9a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from flask import Flask, request, jsonify
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import os
app = Flask(__name__)

# The tokenizer and model are created once, at import time, so every request
# handled by this process reuses the same instances.
print("Loading model and tokenizer...")
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so it can be chained here; it
# puts layers such as dropout into inference mode.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME).eval()
print("Model and tokenizer loaded successfully!")

# Display name for each output index of the classifier (class id -> label).
# NOTE(review): the checkpoint name suggests an SST-2 sentiment model
# (negative/positive outputs); these names only relabel its two outputs and do
# not change what the model predicts -- confirm this mapping is intended.
CUSTOM_LABELS = dict(enumerate(("Business/Professional", "Personal/Casual")))
def classify_text(text):
    """Run the loaded classifier on ``text`` and return its prediction.

    Args:
        text: Input string (e.g. an email subject). It is tokenized with
            truncation to at most 512 tokens.

    Returns:
        A dict with:
          - 'category': the CUSTOM_LABELS name of the highest-scoring class,
          - 'confidence': that class's softmax probability, rounded to 3 places,
          - 'all_categories': every CUSTOM_LABELS name mapped to its rounded
            probability, in class-id order.

    Raises:
        Any exception raised during tokenization or inference; it is printed
        and then re-raised unchanged.
    """
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Only the first sequence of the batch is reported. Take the argmax
        # along the class dimension of that row: a bare logits.argmax() would
        # flatten the whole batch and could return an index past the last
        # class when more than one sequence is present.
        row_logits = logits[0]
        predicted_class_id = int(row_logits.argmax(dim=-1))
        prob_values = torch.nn.functional.softmax(row_logits, dim=-1).tolist()
        return {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(prob_values[predicted_class_id], 3),
            # Map each label by its explicit class id rather than relying on
            # the dict's value order matching the model's output order.
            'all_categories': {
                CUSTOM_LABELS[class_id]: round(prob, 3)
                for class_id, prob in enumerate(prob_values)
                if class_id in CUSTOM_LABELS
            },
        }
    except Exception as e:
        print(f"Error in classify_text: {str(e)}")
        raise
@app.route('/classify', methods=['POST'])
def classify_email():
    """Handle ``POST /classify``: classify the ``"subject"`` field of the JSON body.

    Expects a JSON object such as ``{"subject": "Quarterly report"}``.

    Returns:
        200 with the classify_text() result on success;
        400 if the body is not a JSON object, has no "subject" key, or the
            subject is not a string;
        500 with the exception message if classification fails.
    """
    # silent=True makes get_json() return None for a missing, non-JSON, or
    # malformed body instead of raising, so such requests get the 400 below
    # rather than being turned into a 500 by the handler further down.
    data = request.get_json(silent=True)
    # A JSON array or scalar would pass an `in` test (or break it) and then
    # fail on indexing, so require an object.
    if not isinstance(data, dict) or 'subject' not in data:
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400
    subject = data['subject']
    # Non-string values (numbers, null, lists) are a client error, not a
    # server failure inside the tokenizer.
    if not isinstance(subject, str):
        return jsonify({'error': '"subject" must be a string.'}), 400
    try:
        result = classify_text(subject)
    except Exception as e:
        print(f"Error in classify_email: {str(e)}")
        return jsonify({'error': str(e)}), 500
    return jsonify(result)
@app.route('/', methods=['GET'])
def home():
    """Return a status message plus instructions for calling /classify."""
    usage_info = {
        'endpoint': '/classify',
        'method': 'POST',
        'body': {'subject': 'Your email subject here'},
    }
    payload = {
        'status': 'API is running',
        'model_name': MODEL_NAME,
        'usage': usage_info,
    }
    return jsonify(payload)
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute arbitrary code; with host='0.0.0.0' that
    # is every network peer. Its reloader also re-runs this script in a child
    # process, loading the model a second time. Make it opt-in via FLASK_DEBUG.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    print(f"Starting server on port {port}...")
    app.run(host='0.0.0.0', port=port, debug=debug)