aideveloper24 commited on
Commit
97b6442
·
verified ·
1 Parent(s): 07ee2d0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, make_response
2
+ import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
+ import os
5
+
6
+ app = Flask(__name__)
7
+
8
+ # Global variables to store model and tokenizer
9
+ global_tokenizer = None
10
+ global_model = None
11
+
12
+ def load_model():
13
+ """Load the model and tokenizer"""
14
+ global global_tokenizer, global_model
15
+ try:
16
+ print("Loading model and tokenizer...")
17
+ MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
18
+ global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
19
+ global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
20
+ global_model.eval()
21
+ print("Model loaded successfully!")
22
+ return True
23
+ except Exception as e:
24
+ print(f"Error loading model: {str(e)}")
25
+ return False
26
+
27
+ # Load model at startup
28
+ load_model()
29
+
30
+ @app.route('/', methods=['GET'])
31
+ def home():
32
+ """Home endpoint to check if API is running"""
33
+ response = {
34
+ 'status': 'API is running',
35
+ 'model_status': 'loaded' if global_model is not None else 'not loaded',
36
+ 'usage': {
37
+ 'endpoint': '/classify',
38
+ 'method': 'POST',
39
+ 'body': {'subject': 'Your email subject here'}
40
+ }
41
+ }
42
+ return jsonify(response)
43
+
44
+ @app.route('/health', methods=['GET'])
45
+ def health_check():
46
+ """Health check endpoint"""
47
+ if global_model is None or global_tokenizer is None:
48
+ return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
49
+ return jsonify({'status': 'healthy'})
50
+
51
+ @app.route('/classify', methods=['POST'])
52
+ def classify_email():
53
+ """Classify email subject"""
54
+ if global_model is None or global_tokenizer is None:
55
+ return jsonify({'error': 'Model not loaded'}), 503
56
+
57
+ try:
58
+ # Get request data
59
+ data = request.get_json()
60
+
61
+ if not data or 'subject' not in data:
62
+ return jsonify({
63
+ 'error': 'No subject provided. Please send a JSON with "subject" field.'
64
+ }), 400
65
+
66
+ # Get the subject
67
+ subject = data['subject']
68
+
69
+ # Tokenize
70
+ inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
71
+
72
+ # Predict
73
+ with torch.no_grad():
74
+ outputs = global_model(**inputs)
75
+ logits = outputs.logits
76
+
77
+ # Get probabilities
78
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
79
+ predicted_class_id = logits.argmax().item()
80
+ confidence = probabilities[0][predicted_class_id].item()
81
+
82
+ # Map to custom labels
83
+ CUSTOM_LABELS = {
84
+ 0: "Business/Professional",
85
+ 1: "Personal/Casual"
86
+ }
87
+
88
+ result = {
89
+ 'category': CUSTOM_LABELS[predicted_class_id],
90
+ 'confidence': round(confidence, 3),
91
+ 'all_categories': {
92
+ label: round(prob.item(), 3)
93
+ for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
94
+ }
95
+ }
96
+
97
+ return jsonify(result)
98
+
99
+ except Exception as e:
100
+ print(f"Error in classification: {str(e)}")
101
+ return jsonify({'error': str(e)}), 500
102
+
103
+ if __name__ == '__main__':
104
+ # Use port 7860 for Hugging Face Spaces
105
+ port = int(os.environ.get('PORT', 7860))
106
+ app.run(host='0.0.0.0', port=port)