Update app.py
Browse files
app.py
CHANGED
@@ -14,8 +14,11 @@ def load_model():
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
-
# Replace this path with your model's directory or Hugging Face model
|
18 |
-
MODEL_NAME = "
|
|
|
|
|
|
|
19 |
|
20 |
# Load tokenizer and model from Hugging Face Hub or a local path
|
21 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
@@ -70,10 +73,10 @@ def classify_email():
|
|
70 |
# Get the subject
|
71 |
subject = data['subject']
|
72 |
|
73 |
-
# Tokenize
|
74 |
inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
|
75 |
|
76 |
-
# Predict
|
77 |
with torch.no_grad():
|
78 |
outputs = global_model(**inputs)
|
79 |
logits = outputs.logits
|
@@ -89,7 +92,6 @@ def classify_email():
|
|
89 |
1: "Personal/Casual"
|
90 |
}
|
91 |
|
92 |
-
# Create the response
|
93 |
result = {
|
94 |
'category': CUSTOM_LABELS[predicted_class_id],
|
95 |
'confidence': round(confidence, 3),
|
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
+
# Replace this path with your model's directory or Hugging Face model
|
18 |
+
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" # Test with a known model for now
|
19 |
+
|
20 |
+
# If you have a local model path, use the path to your model
|
21 |
+
# model_dir = "/path/to/your/local/model"
|
22 |
|
23 |
# Load tokenizer and model from Hugging Face Hub or a local path
|
24 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
|
|
73 |
# Get the subject
|
74 |
subject = data['subject']
|
75 |
|
76 |
+
# Tokenize
|
77 |
inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
|
78 |
|
79 |
+
# Predict
|
80 |
with torch.no_grad():
|
81 |
outputs = global_model(**inputs)
|
82 |
logits = outputs.logits
|
|
|
92 |
1: "Personal/Casual"
|
93 |
}
|
94 |
|
|
|
95 |
result = {
|
96 |
'category': CUSTOM_LABELS[predicted_class_id],
|
97 |
'confidence': round(confidence, 3),
|