Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from flask import Flask, request, jsonify
|
2 |
import torch
|
3 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
4 |
import os
|
@@ -14,11 +14,8 @@ def load_model():
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
-
# Replace this path with your model's directory
|
18 |
-
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" #
|
19 |
-
|
20 |
-
# If you have a local model path, use the path to your model
|
21 |
-
# model_dir = "/path/to/your/local/model"
|
22 |
|
23 |
# Load tokenizer and model from Hugging Face Hub or a local path
|
24 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
@@ -86,7 +83,7 @@ def classify_email():
|
|
86 |
predicted_class_id = logits.argmax().item()
|
87 |
confidence = probabilities[0][predicted_class_id].item()
|
88 |
|
89 |
-
#
|
90 |
CUSTOM_LABELS = {
|
91 |
0: "Business/Professional",
|
92 |
1: "Personal/Casual"
|
|
|
1 |
+
from flask import Flask, request, jsonify
|
2 |
import torch
|
3 |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
|
4 |
import os
|
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
+
# Replace this path with your model's directory if using a custom model
|
18 |
+
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" # Example model, replace with your own if needed
|
|
|
|
|
|
|
19 |
|
20 |
# Load tokenizer and model from Hugging Face Hub or a local path
|
21 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
|
|
83 |
predicted_class_id = logits.argmax().item()
|
84 |
confidence = probabilities[0][predicted_class_id].item()
|
85 |
|
86 |
+
# Define custom categories (Modify this as needed)
|
87 |
CUSTOM_LABELS = {
|
88 |
0: "Business/Professional",
|
89 |
1: "Personal/Casual"
|