SamanthaStorm commited on
Commit
b63e8f7
·
verified ·
1 Parent(s): c3eb637

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +290 -135
models.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import logging
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 
4
  from torch.nn.functional import sigmoid, softmax
5
 
6
  # Set up logging
@@ -18,169 +20,314 @@ class ModelManager:
18
  self.tokenizers = {}
19
 
20
  def load_models(self):
21
- """Load all required models"""
22
- # Core abuse pattern detection model
23
- self._load_model(
24
- "abuse_patterns",
25
- "SamanthaStorm/tether-multilabel-v6",
26
- is_multilabel=True
27
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Sentiment model
30
- self._load_model(
31
- "sentiment",
32
- "SamanthaStorm/tether-sentiment-v3",
33
- is_multilabel=False
34
- )
 
 
 
 
 
 
 
35
 
36
- # DARVO model
37
- self._load_model(
38
- "darvo",
39
- "SamanthaStorm/tether-darvo-regressor-v1",
40
- is_multilabel=False,
41
- is_regression=True
42
- )
43
 
44
- # Boundary health model
45
- self._load_model(
46
- "boundary",
47
- "SamanthaStorm/healthy-boundary-predictor",
48
- is_multilabel=False
49
- )
50
 
51
- # Intent analyzer model
52
- self._load_model(
53
- "intent",
54
- "SamanthaStorm/intentanalyzer",
55
- is_multilabel=False
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Emotion model
59
- try:
60
- from transformers import pipeline
61
- self.emotion_pipeline = pipeline(
62
- "text-classification",
63
- model="j-hartmann/emotion-english-distilroberta-base",
64
- return_all_scores=True,
65
- top_k=None,
66
- truncation=True,
67
- device=0 if torch.cuda.is_available() else -1
68
- )
69
- logger.info("Emotion pipeline loaded successfully")
70
- except Exception as e:
71
- logger.error(f"Error loading emotion pipeline: {e}")
72
- self.emotion_pipeline = None
73
-
74
- logger.info("All models loaded successfully")
 
 
 
 
 
 
 
 
 
75
 
76
- def _load_model(self, name, model_path, is_multilabel=False, is_regression=False):
77
- """Helper to load a model and its tokenizer"""
78
- try:
79
- logger.info(f"Loading {name} model from {model_path}")
80
- self.models[name] = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
81
- self.tokenizers[name] = AutoTokenizer.from_pretrained(model_path, use_fast=False)
82
-
83
- # Store model metadata
84
- self.models[name].is_multilabel = is_multilabel
85
- self.models[name].is_regression = is_regression
86
-
87
- logger.info(f"{name} model loaded successfully")
88
- except Exception as e:
89
- logger.error(f"Error loading {name} model: {e}")
90
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def predict_abuse_patterns(self, text, thresholds):
93
  """Predict abuse patterns with thresholds"""
94
  if not text.strip():
95
  return [], []
96
 
97
- inputs = self._prepare_inputs("abuse_patterns", text)
98
-
99
- with torch.no_grad():
100
- outputs = self.models["abuse_patterns"](**inputs)
101
-
102
- # Get sigmoid scores for multi-label classification
103
- raw_scores = torch.sigmoid(outputs.logits.squeeze(0)).cpu().numpy()
104
-
105
- # Get labels
106
- labels = self.get_abuse_pattern_labels()
107
-
108
- # Apply thresholds and return
109
- predictions = list(zip(labels, raw_scores))
110
- matched_scores = []
111
- threshold_labels = []
112
-
113
- for label, score in predictions:
114
- if score > thresholds.get(label, 0.25):
115
- threshold_labels.append(label)
116
- weight = self.get_pattern_weight(label)
117
- matched_scores.append((label, float(score), weight))
118
-
119
- return threshold_labels, matched_scores
 
 
 
 
 
 
 
 
 
120
 
121
  def predict_sentiment(self, text):
122
  """Predict sentiment (supportive vs undermining)"""
123
  if not text.strip():
124
  return "neutral", 0.5
125
 
126
- inputs = self._prepare_inputs("sentiment", text)
127
-
128
- with torch.no_grad():
129
- outputs = self.models["sentiment"](**inputs)
130
- logits = outputs.logits[0]
131
- probs = softmax(logits, dim=-1).cpu().numpy()
132
-
133
- # Get sentiment labels
134
- labels = ["supportive", "undermining"]
135
- sentiment = labels[int(probs.argmax())]
136
- confidence = float(probs.max())
137
-
138
- return sentiment, confidence
 
 
 
 
 
139
 
140
  def predict_darvo(self, text):
141
  """Predict DARVO score"""
142
  if not text.strip():
143
  return 0.0
144
 
145
- inputs = self._prepare_inputs("darvo", text)
146
-
147
- with torch.no_grad():
148
- logits = self.models["darvo"](**inputs).logits
149
- score = float(sigmoid(logits.cpu()).item())
150
-
151
- return score
 
 
 
 
 
 
 
 
 
 
152
 
153
  def predict_boundary_health(self, text):
154
  """Predict boundary health (1 for healthy, 0 for unhealthy)"""
155
  if not text.strip():
156
  return 0
157
 
158
- inputs = self._prepare_inputs("boundary", text)
159
-
160
- with torch.no_grad():
161
- outputs = self.models["boundary"](**inputs)
162
- predictions = softmax(outputs.logits, dim=-1)
163
- predicted_class = torch.argmax(predictions, dim=-1).item()
164
-
165
- return predicted_class
 
 
 
 
 
166
 
167
  def predict_intent(self, text):
168
  """Predict intent"""
169
  if not text.strip():
170
  return "neutral", 0.5
171
 
172
- inputs = self._prepare_inputs("intent", text)
173
-
174
- with torch.no_grad():
175
- outputs = self.models["intent"](**inputs)
176
- probs = softmax(outputs.logits, dim=-1).cpu().numpy()[0]
177
-
178
- # Get intent labels (adjust based on actual model outputs)
179
- labels = ["neutral", "manipulative", "supportive", "controlling"]
180
- intent = labels[int(probs.argmax())]
181
- confidence = float(probs.max())
182
-
183
- return intent, confidence
 
 
 
 
 
184
 
185
  def get_emotion_profile(self, text):
186
  """Get emotion profile from text"""
@@ -213,13 +360,21 @@ class ModelManager:
213
 
214
  def _prepare_inputs(self, model_name, text):
215
  """Prepare inputs for the model"""
216
- inputs = self.tokenizers[model_name](
217
- text,
218
- return_tensors="pt",
219
- truncation=True,
220
- padding=True
221
- )
222
- return {k: v.to(self.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
223
 
224
  def get_abuse_pattern_labels(self):
225
  """Get abuse pattern labels"""
 
1
  import torch
2
  import logging
3
+ import os
4
+ import time
5
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
6
  from torch.nn.functional import sigmoid, softmax
7
 
8
  # Set up logging
 
20
  self.tokenizers = {}
21
 
22
  def load_models(self):
23
+ """Load all required models with retry logic and fallbacks"""
24
+ # Define models to load with fallbacks
25
+ model_configs = [
26
+ {
27
+ "name": "abuse_patterns",
28
+ "primary_path": "SamanthaStorm/tether-multilabel-v6",
29
+ "fallback_path": "SamanthaStorm/tether-multilabel-v5", # Fallback to older version
30
+ "is_multilabel": True
31
+ },
32
+ {
33
+ "name": "sentiment",
34
+ "primary_path": "SamanthaStorm/tether-sentiment-v3",
35
+ "fallback_path": "SamanthaStorm/tether-sentiment-v2",
36
+ "is_multilabel": False
37
+ },
38
+ {
39
+ "name": "darvo",
40
+ "primary_path": "SamanthaStorm/tether-darvo-regressor-v1",
41
+ "fallback_path": None, # No fallback, will use dummy model if fails
42
+ "is_multilabel": False,
43
+ "is_regression": True
44
+ },
45
+ {
46
+ "name": "boundary",
47
+ "primary_path": "SamanthaStorm/healthy-boundary-predictor",
48
+ "fallback_path": None, # No fallback, will use dummy model if fails
49
+ "is_multilabel": False
50
+ },
51
+ {
52
+ "name": "intent",
53
+ "primary_path": "SamanthaStorm/intentanalyzer",
54
+ "fallback_path": None, # No fallback, will use dummy model if fails
55
+ "is_multilabel": False
56
+ }
57
+ ]
58
 
59
+ # Load each model with retry logic
60
+ for config in model_configs:
61
+ success = self._load_model_with_retry(
62
+ config["name"],
63
+ config["primary_path"],
64
+ config["fallback_path"],
65
+ is_multilabel=config.get("is_multilabel", False),
66
+ is_regression=config.get("is_regression", False)
67
+ )
68
+
69
+ if not success:
70
+ logger.warning(f"Creating dummy model for {config['name']}")
71
+ self._create_dummy_model(config["name"], config.get("is_multilabel", False))
72
 
73
+ # Load emotion pipeline separately with retry
74
+ self._load_emotion_pipeline()
 
 
 
 
 
75
 
76
+ logger.info("Model loading completed")
 
 
 
 
 
77
 
78
+ def _load_model_with_retry(self, name, primary_path, fallback_path=None, is_multilabel=False, is_regression=False, max_retries=3):
79
+ """Load a model with retry logic and fallback option"""
80
+ for attempt in range(max_retries):
81
+ try:
82
+ logger.info(f"Loading {name} model from {primary_path} (attempt {attempt+1}/{max_retries})")
83
+
84
+ # Try to load from primary path
85
+ self.models[name] = AutoModelForSequenceClassification.from_pretrained(
86
+ primary_path,
87
+ local_files_only=False,
88
+ trust_remote_code=False
89
+ ).to(self.device)
90
+
91
+ self.tokenizers[name] = AutoTokenizer.from_pretrained(
92
+ primary_path,
93
+ use_fast=False,
94
+ local_files_only=False,
95
+ trust_remote_code=False
96
+ )
97
+
98
+ # Store model metadata
99
+ self.models[name].is_multilabel = is_multilabel
100
+ self.models[name].is_regression = is_regression
101
+
102
+ logger.info(f"{name} model loaded successfully")
103
+ return True
104
+
105
+ except Exception as e:
106
+ logger.error(f"Error loading {name} model (attempt {attempt+1}): {e}")
107
+ time.sleep(2) # Wait before retry
108
 
109
+ # If primary path failed, try fallback if available
110
+ if fallback_path:
111
+ try:
112
+ logger.info(f"Trying fallback path for {name}: {fallback_path}")
113
+ self.models[name] = AutoModelForSequenceClassification.from_pretrained(
114
+ fallback_path,
115
+ local_files_only=False,
116
+ trust_remote_code=False
117
+ ).to(self.device)
118
+
119
+ self.tokenizers[name] = AutoTokenizer.from_pretrained(
120
+ fallback_path,
121
+ use_fast=False,
122
+ local_files_only=False,
123
+ trust_remote_code=False
124
+ )
125
+
126
+ # Store model metadata
127
+ self.models[name].is_multilabel = is_multilabel
128
+ self.models[name].is_regression = is_regression
129
+
130
+ logger.info(f"{name} model loaded from fallback path")
131
+ return True
132
+
133
+ except Exception as e:
134
+ logger.error(f"Error loading {name} model from fallback path: {e}")
135
 
136
+ return False
137
+
138
+ def _load_emotion_pipeline(self, max_retries=3):
139
+ """Load emotion pipeline with retry logic"""
140
+ for attempt in range(max_retries):
141
+ try:
142
+ from transformers import pipeline
143
+ logger.info(f"Loading emotion pipeline (attempt {attempt+1}/{max_retries})")
144
+
145
+ self.emotion_pipeline = pipeline(
146
+ "text-classification",
147
+ model="j-hartmann/emotion-english-distilroberta-base",
148
+ return_all_scores=True,
149
+ top_k=None,
150
+ truncation=True,
151
+ device=0 if torch.cuda.is_available() else -1
152
+ )
153
+
154
+ logger.info("Emotion pipeline loaded successfully")
155
+ return True
156
+
157
+ except Exception as e:
158
+ logger.error(f"Error loading emotion pipeline (attempt {attempt+1}): {e}")
159
+ time.sleep(2) # Wait before retry
160
+
161
+ logger.warning("Failed to load emotion pipeline, using dummy")
162
+ self.emotion_pipeline = None
163
+ return False
164
+
165
+ def _create_dummy_model(self, name, is_multilabel=False):
166
+ """Create a dummy model that returns neutral predictions"""
167
+ class DummyModel:
168
+ def __init__(self, is_multilabel=False):
169
+ self.is_multilabel = is_multilabel
170
+ self.is_regression = False
171
+
172
+ def __call__(self, **kwargs):
173
+ class DummyOutput:
174
+ def __init__(self, is_multilabel):
175
+ if is_multilabel:
176
+ # For multilabel, create logits for each class (16 classes)
177
+ self.logits = torch.zeros((1, 16))
178
+ else:
179
+ # For classification, create logits for 2 classes
180
+ self.logits = torch.zeros((1, 2))
181
+ # Slightly bias toward first class
182
+ self.logits[0, 0] = 0.2
183
+
184
+ return DummyOutput(self.is_multilabel)
185
+
186
+ def eval(self):
187
+ return self
188
+
189
+ def to(self, device):
190
+ return self
191
+
192
+ # Create dummy model and tokenizer
193
+ self.models[name] = DummyModel(is_multilabel)
194
+
195
+ class DummyTokenizer:
196
+ def __call__(self, text, **kwargs):
197
+ return {
198
+ "input_ids": torch.ones((1, 10), dtype=torch.long),
199
+ "attention_mask": torch.ones((1, 10), dtype=torch.long)
200
+ }
201
+
202
+ self.tokenizers[name] = DummyTokenizer()
203
+ logger.warning(f"Created dummy model for {name}")
204
 
205
  def predict_abuse_patterns(self, text, thresholds):
206
  """Predict abuse patterns with thresholds"""
207
  if not text.strip():
208
  return [], []
209
 
210
+ try:
211
+ inputs = self._prepare_inputs("abuse_patterns", text)
212
+
213
+ with torch.no_grad():
214
+ outputs = self.models["abuse_patterns"](**inputs)
215
+
216
+ # Get sigmoid scores for multi-label classification
217
+ if self.models["abuse_patterns"].is_multilabel:
218
+ raw_scores = torch.sigmoid(outputs.logits.squeeze(0)).cpu().numpy()
219
+ else:
220
+ # Fallback for non-multilabel model
221
+ raw_scores = torch.softmax(outputs.logits.squeeze(0), dim=0).cpu().numpy()
222
+
223
+ # Get labels
224
+ labels = self.get_abuse_pattern_labels()
225
+
226
+ # Apply thresholds and return
227
+ predictions = list(zip(labels, raw_scores))
228
+ matched_scores = []
229
+ threshold_labels = []
230
+
231
+ for label, score in predictions:
232
+ if score > thresholds.get(label, 0.25):
233
+ threshold_labels.append(label)
234
+ weight = self.get_pattern_weight(label)
235
+ matched_scores.append((label, float(score), weight))
236
+
237
+ return threshold_labels, matched_scores
238
+
239
+ except Exception as e:
240
+ logger.error(f"Error in predict_abuse_patterns: {e}")
241
+ return [], []
242
 
243
  def predict_sentiment(self, text):
244
  """Predict sentiment (supportive vs undermining)"""
245
  if not text.strip():
246
  return "neutral", 0.5
247
 
248
+ try:
249
+ inputs = self._prepare_inputs("sentiment", text)
250
+
251
+ with torch.no_grad():
252
+ outputs = self.models["sentiment"](**inputs)
253
+ logits = outputs.logits[0]
254
+ probs = softmax(logits, dim=-1).cpu().numpy()
255
+
256
+ # Get sentiment labels
257
+ labels = ["supportive", "undermining"]
258
+ sentiment = labels[int(probs.argmax())]
259
+ confidence = float(probs.max())
260
+
261
+ return sentiment, confidence
262
+
263
+ except Exception as e:
264
+ logger.error(f"Error in predict_sentiment: {e}")
265
+ return "neutral", 0.5
266
 
267
  def predict_darvo(self, text):
268
  """Predict DARVO score"""
269
  if not text.strip():
270
  return 0.0
271
 
272
+ try:
273
+ inputs = self._prepare_inputs("darvo", text)
274
+
275
+ with torch.no_grad():
276
+ logits = self.models["darvo"](**inputs).logits
277
+ if self.models["darvo"].is_regression:
278
+ score = float(sigmoid(logits.cpu()).item())
279
+ else:
280
+ # Fallback for classification model
281
+ probs = softmax(logits, dim=-1).cpu().numpy()[0]
282
+ score = float(probs[1]) # Assume second class is DARVO
283
+
284
+ return score
285
+
286
+ except Exception as e:
287
+ logger.error(f"Error in predict_darvo: {e}")
288
+ return 0.0
289
 
290
  def predict_boundary_health(self, text):
291
  """Predict boundary health (1 for healthy, 0 for unhealthy)"""
292
  if not text.strip():
293
  return 0
294
 
295
+ try:
296
+ inputs = self._prepare_inputs("boundary", text)
297
+
298
+ with torch.no_grad():
299
+ outputs = self.models["boundary"](**inputs)
300
+ predictions = softmax(outputs.logits, dim=-1)
301
+ predicted_class = torch.argmax(predictions, dim=-1).item()
302
+
303
+ return predicted_class
304
+
305
+ except Exception as e:
306
+ logger.error(f"Error in predict_boundary_health: {e}")
307
+ return 0
308
 
309
  def predict_intent(self, text):
310
  """Predict intent"""
311
  if not text.strip():
312
  return "neutral", 0.5
313
 
314
+ try:
315
+ inputs = self._prepare_inputs("intent", text)
316
+
317
+ with torch.no_grad():
318
+ outputs = self.models["intent"](**inputs)
319
+ probs = softmax(outputs.logits, dim=-1).cpu().numpy()[0]
320
+
321
+ # Get intent labels (adjust based on actual model outputs)
322
+ labels = ["neutral", "manipulative", "supportive", "controlling"]
323
+ intent = labels[int(probs.argmax())]
324
+ confidence = float(probs.max())
325
+
326
+ return intent, confidence
327
+
328
+ except Exception as e:
329
+ logger.error(f"Error in predict_intent: {e}")
330
+ return "neutral", 0.5
331
 
332
  def get_emotion_profile(self, text):
333
  """Get emotion profile from text"""
 
360
 
361
  def _prepare_inputs(self, model_name, text):
362
  """Prepare inputs for the model"""
363
+ try:
364
+ inputs = self.tokenizers[model_name](
365
+ text,
366
+ return_tensors="pt",
367
+ truncation=True,
368
+ padding=True
369
+ )
370
+ return {k: v.to(self.device) for k, v in inputs.items()}
371
+ except Exception as e:
372
+ logger.error(f"Error preparing inputs for {model_name}: {e}")
373
+ # Return dummy inputs
374
+ return {
375
+ "input_ids": torch.ones((1, 10), dtype=torch.long).to(self.device),
376
+ "attention_mask": torch.ones((1, 10), dtype=torch.long).to(self.device)
377
+ }
378
 
379
  def get_abuse_pattern_labels(self):
380
  """Get abuse pattern labels"""