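"""Flask API for classifying email subjects with DistilBERT.

Endpoints:
    GET  /         - status and usage information
    GET  /health   - readiness check (503 until the model is loaded)
    POST /classify - classify the "subject" field of a JSON body

Intended to run on Hugging Face Spaces (port 7860 by default).
"""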
from flask import Flask, request, jsonify
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import os

app = Flask(__name__)

# Global variables to store model and tokenizer
global_tokenizer = None
global_model = None

def load_model():
    """Load the model and tokenizer"""
    global global_tokenizer, global_model
    try:
        print("Loading model and tokenizer...")
        MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
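        # from_pretrained downloads the weights on first use and caches
        # them locally, so only the first startup is slow.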
        global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
        global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
        global_model.eval()  # disable dropout etc. for deterministic inference
        print("Model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False

# Load the model at import time so the app is ready before the first
# request, whether run directly or behind a WSGI server such as gunicorn.
load_model()

@app.route('/', methods=['GET'])
def home():
    """Home endpoint to check if API is running"""
    response = {
        'status': 'API is running',
        'model_status': 'loaded' if global_model is not None else 'not loaded',
        'usage': {
            'endpoint': '/classify',
            'method': 'POST',
            'body': {'subject': 'Your email subject here'}
        }
    }
    return jsonify(response)

@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    if global_model is None or global_tokenizer is None:
        return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
    return jsonify({'status': 'healthy'})

@app.route('/classify', methods=['POST'])
def classify_email():
    """Classify email subject"""
    if global_model is None or global_tokenizer is None:
        return jsonify({'error': 'Model not loaded'}), 503

    try:
        # Get request data
        data = request.get_json()
        
        if not data or 'subject' not in data:
            return jsonify({
                'error': 'No subject provided. Please send a JSON with "subject" field.'
            }), 400
        
        # Get the subject
        subject = data['subject']
        
        # Tokenize
        inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
        
        # Predict; no_grad() skips gradient tracking for faster,
        # lower-memory inference
        with torch.no_grad():
            outputs = global_model(**inputs)
            logits = outputs.logits
        
        # Get probabilities
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class_id = logits.argmax().item()
        confidence = probabilities[0][predicted_class_id].item()
        
        # Map to custom labels. The underlying SST-2 checkpoint's native
        # labels are 0 = NEGATIVE and 1 = POSITIVE; this mapping simply
        # renames those two outputs for the email use case.
        CUSTOM_LABELS = {
            0: "Business/Professional",
            1: "Personal/Casual"
        }
        
        result = {
            'category': CUSTOM_LABELS[predicted_class_id],
            'confidence': round(confidence, 3),
            'all_categories': {
                label: round(prob.item(), 3) 
                for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
            }
        }
        
        return jsonify(result)
    
    except Exception as e:
        print(f"Error in classification: {str(e)}")
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Use port 7860 for Hugging Face Spaces
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
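
# Example usage (illustrative; the exact scores depend on the input):
#   curl -X POST http://localhost:7860/classify \
#        -H "Content-Type: application/json" \
#        -d '{"subject": "Quarterly revenue report attached"}'
#
# Example response shape:
#   {"category": "Business/Professional", "confidence": 0.973,
#    "all_categories": {"Business/Professional": 0.973,
#                       "Personal/Casual": 0.027}}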