aideveloper24 commited on
Commit
bbba3bf
·
verified ·
1 Parent(s): 195370c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -14,8 +14,11 @@ def load_model():
14
  global global_tokenizer, global_model
15
  try:
16
  print("Loading model and tokenizer...")
17
- # Replace this path with your model's directory or Hugging Face model name
18
- MODEL_NAME = "aideveloper24/email_classify" # Replace with your custom model name
 
 
 
19
 
20
  # Load tokenizer and model from Hugging Face Hub or a local path
21
  global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
@@ -70,10 +73,10 @@ def classify_email():
70
  # Get the subject
71
  subject = data['subject']
72
 
73
- # Tokenize the subject text
74
  inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
75
 
76
- # Predict the class
77
  with torch.no_grad():
78
  outputs = global_model(**inputs)
79
  logits = outputs.logits
@@ -89,7 +92,6 @@ def classify_email():
89
  1: "Personal/Casual"
90
  }
91
 
92
- # Create the response
93
  result = {
94
  'category': CUSTOM_LABELS[predicted_class_id],
95
  'confidence': round(confidence, 3),
 
14
  global global_tokenizer, global_model
15
  try:
16
  print("Loading model and tokenizer...")
17
+ # Replace this path with your model's directory or Hugging Face model
18
+ MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" # Test with a known model for now
19
+
20
+ # If you have a local model path, use the path to your model
21
+ # model_dir = "/path/to/your/local/model"
22
 
23
  # Load tokenizer and model from Hugging Face Hub or a local path
24
  global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
 
73
  # Get the subject
74
  subject = data['subject']
75
 
76
+ # Tokenize
77
  inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
78
 
79
+ # Predict
80
  with torch.no_grad():
81
  outputs = global_model(**inputs)
82
  logits = outputs.logits
 
92
  1: "Personal/Casual"
93
  }
94
 
 
95
  result = {
96
  'category': CUSTOM_LABELS[predicted_class_id],
97
  'confidence': round(confidence, 3),