SamanthaStorm commited on
Commit
f7d2bee
·
verified ·
1 Parent(s): ba85ad6

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +18 -337
models.py CHANGED
@@ -116,7 +116,7 @@ class ModelManager:
116
  try:
117
  logger.info("Loading custom intent model with MultiLabelIntentClassifier")
118
 
119
- # Create the custom model architecture (referencing the global class)
120
  intent_model = MultiLabelIntentClassifier("distilbert-base-uncased", 6)
121
 
122
  # Download the model file from HuggingFace
@@ -174,43 +174,27 @@ class ModelManager:
174
  try:
175
  logger.info(f"Trying fallback path for {name}: {fallback_path}")
176
 
177
- # Special handling for intent model fallback too
178
  if name == "intent":
 
 
 
179
  try:
180
- try:
181
- from transformers import MultiLabelIntentClassifier
182
- except ImportError:
183
- self.models[name] = AutoModelForSequenceClassification.from_pretrained(
184
- fallback_path,
185
- local_files_only=False,
186
- trust_remote_code=True
187
- ).to(self.device)
188
- else:
189
- self.models[name] = MultiLabelIntentClassifier.from_pretrained(
190
- fallback_path,
191
- local_files_only=False,
192
- trust_remote_code=True
193
- ).to(self.device)
194
-
195
- self.tokenizers[name] = AutoTokenizer.from_pretrained(
196
- fallback_path,
197
- use_fast=False,
198
- local_files_only=False,
199
- trust_remote_code=True
200
  )
201
- except Exception:
202
- self.models[name] = AutoModelForSequenceClassification.from_pretrained(
203
- fallback_path,
204
- local_files_only=False,
205
- trust_remote_code=True
206
- ).to(self.device)
207
 
208
- self.tokenizers[name] = AutoTokenizer.from_pretrained(
209
- fallback_path,
210
- use_fast=False,
211
- local_files_only=False,
212
- trust_remote_code=True
213
- )
 
 
 
 
214
  else:
215
  self.models[name] = AutoModelForSequenceClassification.from_pretrained(
216
  fallback_path,
@@ -236,306 +220,3 @@ class ModelManager:
236
  logger.error(f"Error loading {name} model from fallback path: {e}")
237
 
238
  return False
239
-
240
- def _load_emotion_pipeline(self, max_retries=3):
241
- """Load emotion pipeline with retry logic"""
242
- for attempt in range(max_retries):
243
- try:
244
- logger.info(f"Loading emotion pipeline (attempt {attempt+1}/{max_retries})")
245
-
246
- self.emotion_pipeline = pipeline(
247
- "text-classification",
248
- model="j-hartmann/emotion-english-distilroberta-base",
249
- return_all_scores=True,
250
- top_k=None,
251
- truncation=True,
252
- device=0 if torch.cuda.is_available() else -1
253
- )
254
-
255
- logger.info("Emotion pipeline loaded successfully")
256
- return True
257
-
258
- except Exception as e:
259
- logger.error(f"Error loading emotion pipeline (attempt {attempt+1}): {e}")
260
- time.sleep(2) # Wait before retry
261
-
262
- logger.error("Failed to load emotion pipeline after all retries")
263
- return False
264
-
265
- def predict_fallacy(self, text):
266
- """Predict logical fallacy using FallacyFinder model"""
267
- if not text.strip():
268
- return "No Fallacy", 0.0
269
-
270
- try:
271
- inputs = self._prepare_inputs("fallacy", text)
272
-
273
- with torch.no_grad():
274
- outputs = self.models["fallacy"](**inputs)
275
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
276
- predicted_class_id = predictions.argmax().item()
277
- confidence = predictions.max().item()
278
-
279
- # Get the label from model config or use fallback labels
280
- if hasattr(self.models["fallacy"], 'config') and hasattr(self.models["fallacy"].config, 'id2label'):
281
- predicted_label = self.models["fallacy"].config.id2label[predicted_class_id]
282
- else:
283
- # Fallback labels in case config is missing
284
- fallacy_labels = [
285
- "Ad Hominem", "Strawman", "Whataboutism", "Gaslighting",
286
- "False Dichotomy", "Appeal to Emotion", "DARVO", "Moving Goalposts",
287
- "Cherry Picking", "Appeal to Authority", "Slippery Slope",
288
- "Motte and Bailey", "Gish Gallop", "Kafkatrapping", "Sealioning", "No Fallacy"
289
- ]
290
- predicted_label = fallacy_labels[predicted_class_id] if predicted_class_id < len(fallacy_labels) else "No Fallacy"
291
-
292
- return predicted_label, float(confidence)
293
-
294
- except Exception as e:
295
- logger.error(f"Error in predict_fallacy: {e}")
296
- return "No Fallacy", 0.0
297
-
298
- def predict_abuse_patterns(self, text, thresholds):
299
- """Predict abuse patterns with thresholds"""
300
- if not text.strip():
301
- return [], []
302
-
303
- try:
304
- inputs = self._prepare_inputs("abuse_patterns", text)
305
-
306
- with torch.no_grad():
307
- outputs = self.models["abuse_patterns"](**inputs)
308
-
309
- # Get sigmoid scores for multi-label classification
310
- if self.models["abuse_patterns"].is_multilabel:
311
- raw_scores = torch.sigmoid(outputs.logits.squeeze(0)).cpu().numpy()
312
- else:
313
- # Fallback for non-multilabel model
314
- raw_scores = torch.softmax(outputs.logits.squeeze(0), dim=0).cpu().numpy()
315
-
316
- # Get labels
317
- labels = self.get_abuse_pattern_labels()
318
-
319
- # Apply thresholds and return
320
- predictions = list(zip(labels, raw_scores))
321
- matched_scores = []
322
- threshold_labels = []
323
-
324
- for label, score in predictions:
325
- if score > thresholds.get(label, 0.25):
326
- threshold_labels.append(label)
327
- weight = self.get_pattern_weight(label)
328
- matched_scores.append((label, float(score), weight))
329
-
330
- return threshold_labels, matched_scores
331
-
332
- except Exception as e:
333
- logger.error(f"Error in predict_abuse_patterns: {e}")
334
- return [], []
335
-
336
- def predict_sentiment(self, text):
337
- """Predict sentiment (supportive vs undermining)"""
338
- if not text.strip():
339
- return "neutral", 0.5
340
-
341
- try:
342
- inputs = self._prepare_inputs("sentiment", text)
343
-
344
- with torch.no_grad():
345
- outputs = self.models["sentiment"](**inputs)
346
- logits = outputs.logits[0]
347
- probs = softmax(logits, dim=-1).cpu().numpy()
348
-
349
- # Get sentiment labels
350
- labels = ["supportive", "undermining"]
351
- sentiment = labels[int(probs.argmax())]
352
- confidence = float(probs.max())
353
-
354
- return sentiment, confidence
355
-
356
- except Exception as e:
357
- logger.error(f"Error in predict_sentiment: {e}")
358
- return "neutral", 0.5
359
-
360
- def predict_darvo(self, text):
361
- """Predict DARVO score"""
362
- if not text.strip():
363
- return 0.0
364
-
365
- try:
366
- inputs = self._prepare_inputs("darvo", text)
367
-
368
- with torch.no_grad():
369
- logits = self.models["darvo"](**inputs).logits
370
- if self.models["darvo"].is_regression:
371
- score = float(sigmoid(logits.cpu()).item())
372
- else:
373
- # Fallback for classification model
374
- probs = softmax(logits, dim=-1).cpu().numpy()[0]
375
- score = float(probs[1]) # Assume second class is DARVO
376
-
377
- return score
378
-
379
- except Exception as e:
380
- logger.error(f"Error in predict_darvo: {e}")
381
- return 0.0
382
-
383
- def predict_boundary_health(self, text):
384
- """Predict boundary health (1 for healthy, 0 for unhealthy)"""
385
- if not text.strip():
386
- return 0
387
-
388
- try:
389
- inputs = self._prepare_inputs("boundary", text)
390
-
391
- with torch.no_grad():
392
- outputs = self.models["boundary"](**inputs)
393
- predictions = softmax(outputs.logits, dim=-1)
394
- predicted_class = torch.argmax(predictions, dim=-1).item()
395
-
396
- return predicted_class
397
-
398
- except Exception as e:
399
- logger.error(f"Error in predict_boundary_health: {e}")
400
- return 0
401
-
402
- def predict_intent(self, text):
403
- """Predict intent using custom multilabel classification model"""
404
- if not text.strip():
405
- return "neutral", 0.5
406
-
407
- # Check if intent model is available
408
- if "intent" not in self.models:
409
- logger.warning("Intent model not available, returning neutral intent")
410
- return "neutral", 0.5
411
-
412
- try:
413
- self.models["intent"].eval()
414
-
415
- inputs = self.tokenizers["intent"](text, return_tensors="pt", truncation=True, padding=True, max_length=128)
416
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
417
-
418
- with torch.no_grad():
419
- outputs = self.models["intent"](inputs['input_ids'], inputs['attention_mask'])
420
- probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
421
-
422
- # Intent categories (same as your working app)
423
- intent_categories = ['trolling', 'dismissive', 'manipulative', 'emotionally_reactive', 'constructive', 'unclear']
424
- intent_thresholds = {
425
- 'trolling': 0.70,
426
- 'manipulative': 0.65,
427
- 'dismissive': 0.60,
428
- 'constructive': 0.60,
429
- 'emotionally_reactive': 0.55,
430
- 'unclear': 0.50
431
- }
432
-
433
- # Get predictions above threshold
434
- detected_intents = {}
435
- for i, category in enumerate(intent_categories):
436
- prob = probabilities[i]
437
- threshold = intent_thresholds[category]
438
- if prob > threshold:
439
- detected_intents[category] = prob
440
-
441
- # If no intents above threshold, use the highest one if it's reasonable
442
- if not detected_intents:
443
- max_idx = probabilities.argmax()
444
- max_category = intent_categories[max_idx]
445
- max_prob = probabilities[max_idx]
446
- if max_prob > 0.3: # Minimum confidence
447
- detected_intents[max_category] = max_prob
448
-
449
- # Return primary intent for compatibility with existing analyzer
450
- if detected_intents:
451
- primary_intent = max(detected_intents.items(), key=lambda x: x[1])
452
- return primary_intent[0], primary_intent[1]
453
- else:
454
- return "neutral", 0.5
455
-
456
- except Exception as e:
457
- logger.error(f"Error in predict_intent: {e}")
458
- return "neutral", 0.5
459
-
460
- def get_emotion_profile(self, text):
461
- """Get emotion profile from text"""
462
- if not text.strip() or not self.emotion_pipeline:
463
- return {
464
- "sadness": 0.0,
465
- "joy": 0.0,
466
- "neutral": 0.0,
467
- "disgust": 0.0,
468
- "anger": 0.0,
469
- "fear": 0.0
470
- }
471
-
472
- try:
473
- emotions = self.emotion_pipeline(text)
474
- if isinstance(emotions, list) and isinstance(emotions[0], list):
475
- emotion_scores = emotions[0]
476
- return {e['label'].lower(): round(e['score'], 3) for e in emotion_scores}
477
- return {}
478
- except Exception as e:
479
- logger.error(f"Error in get_emotion_profile: {e}")
480
- return {
481
- "sadness": 0.0,
482
- "joy": 0.0,
483
- "neutral": 0.0,
484
- "disgust": 0.0,
485
- "anger": 0.0,
486
- "fear": 0.0
487
- }
488
-
489
- def _prepare_inputs(self, model_name, text):
490
- """Prepare inputs for the model"""
491
- try:
492
- # Set max_length for fallacy model to match training
493
- max_length = 512 if model_name == "fallacy" else None
494
-
495
- inputs = self.tokenizers[model_name](
496
- text,
497
- return_tensors="pt",
498
- truncation=True,
499
- padding=True,
500
- max_length=max_length
501
- )
502
- return {k: v.to(self.device) for k, v in inputs.items()}
503
- except Exception as e:
504
- logger.error(f"Error preparing inputs for {model_name}: {e}")
505
- # Return dummy inputs
506
- return {
507
- "input_ids": torch.ones((1, 10), dtype=torch.long).to(self.device),
508
- "attention_mask": torch.ones((1, 10), dtype=torch.long).to(self.device)
509
- }
510
-
511
- def get_abuse_pattern_labels(self):
512
- """Get abuse pattern labels"""
513
- return [
514
- "recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",
515
- "blame shifting", "nonabusive", "projection", "insults",
516
- "contradictory statements", "obscure language",
517
- "veiled threats", "stalking language", "false concern",
518
- "false equivalence", "future faking"
519
- ]
520
-
521
- def get_pattern_weight(self, label):
522
- """Get pattern weight for scoring"""
523
- weights = {
524
- "recovery phase": 0.7,
525
- "control": 1.4,
526
- "gaslighting": 1.3,
527
- "guilt tripping": 1.2,
528
- "dismissiveness": 0.9,
529
- "blame shifting": 1.0,
530
- "projection": 0.5,
531
- "insults": 1.4,
532
- "contradictory statements": 1.0,
533
- "obscure language": 0.9,
534
- "nonabusive": 0.0,
535
- "veiled threats": 1.6,
536
- "stalking language": 1.8,
537
- "false concern": 1.1,
538
- "false equivalence": 1.3,
539
- "future faking": 0.8
540
- }
541
- return weights.get(label, 1.0)
 
116
  try:
117
  logger.info("Loading custom intent model with MultiLabelIntentClassifier")
118
 
119
+ # Create the custom model architecture using the class defined at module level
120
  intent_model = MultiLabelIntentClassifier("distilbert-base-uncased", 6)
121
 
122
  # Download the model file from HuggingFace
 
174
  try:
175
  logger.info(f"Trying fallback path for {name}: {fallback_path}")
176
 
177
+ # Special handling for intent model fallback
178
  if name == "intent":
179
+ # Use the locally defined MultiLabelIntentClassifier class
180
+ custom_model = MultiLabelIntentClassifier("distilbert-base-uncased", 6)
181
+
182
  try:
183
+ model_path = hf_hub_download(
184
+ repo_id=fallback_path,
185
+ filename="pytorch_model.bin"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  )
 
 
 
 
 
 
187
 
188
+ state_dict = torch.load(model_path, map_location='cpu')
189
+ custom_model.load_state_dict(state_dict)
190
+ self.models[name] = custom_model.to(self.device)
191
+
192
+ # Use distilbert tokenizer for intent model
193
+ self.tokenizers[name] = AutoTokenizer.from_pretrained("distilbert-base-uncased")
194
+
195
+ except Exception as fallback_error:
196
+ logger.error(f"Failed to load intent model from fallback: {fallback_error}")
197
+ raise fallback_error
198
  else:
199
  self.models[name] = AutoModelForSequenceClassification.from_pretrained(
200
  fallback_path,
 
220
  logger.error(f"Error loading {name} model from fallback path: {e}")
221
 
222
  return False