aideveloper24 commited on
Commit
b3df9a0
·
verified ·
1 Parent(s): ca290b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
+ import os
5
+
6
+ app = Flask(__name__)
7
+
8
+ # Initialize model and tokenizer globally
9
+ print("Loading model and tokenizer...")
10
+ MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
11
+ tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
12
+ model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
13
+ model.eval()
14
+ print("Model and tokenizer loaded successfully!")
15
+
16
+ # Custom labels
17
+ CUSTOM_LABELS = {
18
+ 0: "Business/Professional",
19
+ 1: "Personal/Casual"
20
+ }
21
+
22
+ def classify_text(text):
23
+ try:
24
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
25
+
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits
29
+
30
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
31
+ predicted_class_id = logits.argmax().item()
32
+ confidence = probabilities[0][predicted_class_id].item()
33
+
34
+ return {
35
+ 'category': CUSTOM_LABELS[predicted_class_id],
36
+ 'confidence': round(confidence, 3),
37
+ 'all_categories': {
38
+ label: round(prob.item(), 3)
39
+ for label, prob in zip(CUSTOM_LABELS.values(), probabilities[0])
40
+ }
41
+ }
42
+ except Exception as e:
43
+ print(f"Error in classify_text: {str(e)}")
44
+ raise
45
+
46
+ @app.route('/classify', methods=['POST'])
47
+ def classify_email():
48
+ try:
49
+ data = request.get_json()
50
+
51
+ if not data or 'subject' not in data:
52
+ return jsonify({
53
+ 'error': 'No subject provided. Please send a JSON with "subject" field.'
54
+ }), 400
55
+
56
+ subject = data['subject']
57
+ result = classify_text(subject)
58
+ return jsonify(result)
59
+
60
+ except Exception as e:
61
+ print(f"Error in classify_email: {str(e)}")
62
+ return jsonify({'error': str(e)}), 500
63
+
64
+ @app.route('/', methods=['GET'])
65
+ def home():
66
+ return jsonify({
67
+ 'status': 'API is running',
68
+ 'model_name': MODEL_NAME,
69
+ 'usage': {
70
+ 'endpoint': '/classify',
71
+ 'method': 'POST',
72
+ 'body': {'subject': 'Your email subject here'}
73
+ }
74
+ })
75
+
76
+ if __name__ == '__main__':
77
+ port = int(os.environ.get('PORT', 7860))
78
+ print(f"Starting server on port {port}...")
79
+ app.run(host='0.0.0.0', port=port, debug=True)