File size: 2,690 Bytes
98944a8 cb536ff 97b6442 cb536ff 97b6442 bbba3bf cb536ff 97b6442 bbba3bf 97b6442 cb536ff 97b6442 98944a8 97b6442 cb536ff 97b6442 1a0232d 97b6442 1a0232d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import os
app = Flask(__name__)

# Checkpoint identifier passed to `from_pretrained`. Kept in one constant so the
# tokenizer and the model can never be loaded from different checkpoints.
MODEL_NAME = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

# Loaded once at module import so every request reuses the same instances.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # Evaluation mode for inference (e.g. dropout is disabled).
@app.route('/', methods=['GET'])
def home():
    """Report that the API is up and describe how to call /classify."""
    usage_info = {
        'endpoint': '/classify',
        'method': 'POST',
        'body': {'subject': 'Your email subject here'},
    }
    return jsonify({'status': 'API is running', 'usage': usage_info})
@app.route('/health', methods=['GET'])
def health_check():
    """Return a fixed JSON payload {'status': 'healthy'} for health probes."""
    return jsonify(status='healthy')
# Maps the model's class ids to the category names returned to clients.
# Modify the names as needed; the keys must be the model's class ids.
CUSTOM_LABELS = {
    0: "Negative",
    1: "Positive",
}


@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify an email subject with the loaded sequence-classification model.

    Expects a JSON object body of the form ``{"subject": "<text>"}``.

    Returns:
        200 with ``category``, ``confidence`` (rounded to 3 places) and
        ``all_categories`` (name -> probability) on success.
        400 with ``error`` when the body is not a JSON object, or when
        ``subject`` is missing, not a string, or blank.
        500 with ``error`` when tokenization, inference or label lookup fails.
    """
    # silent=True yields None (instead of raising) for a wrong Content-Type or
    # malformed JSON, so those client errors get a 400 here rather than being
    # caught by the generic handler below and reported as a 500.
    data = request.get_json(silent=True)
    subject = data.get('subject') if isinstance(data, dict) else None
    # Only a single non-blank string is accepted. A list would make the
    # tokenizer batch its input, which the single-row post-processing below
    # does not support.
    if not isinstance(subject, str) or not subject.strip():
        return jsonify({
            'error': 'No subject provided. Please send a JSON with "subject" field.'
        }), 400

    try:
        inputs = tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = model(**inputs).logits

        # Exactly one input was tokenized, so only row 0 is relevant.
        row_logits = logits[0]
        probabilities = torch.nn.functional.softmax(row_logits, dim=-1)
        predicted_class_id = int(row_logits.argmax())
        confidence = probabilities[predicted_class_id].item()

        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            # Label each probability by its class id rather than relying on
            # the insertion order of CUSTOM_LABELS.
            'all_categories': {
                CUSTOM_LABELS[class_id]: round(prob, 3)
                for class_id, prob in enumerate(probabilities.tolist())
            },
        }
    except Exception as e:
        # Top-level request boundary: log the full traceback for operators.
        app.logger.exception("Error in classification")
        return jsonify({'error': str(e)}), 500

    return jsonify(result)
if __name__ == '__main__':
    # 7860 is the port Hugging Face Spaces expects; set PORT to override it
    # (e.g. for local testing). The server listens on all interfaces.
    listen_port = int(os.getenv('PORT', 7860))
    app.run(port=listen_port, host='0.0.0.0')
|