SamanthaStorm committed on
Commit 483940b · verified · 1 Parent(s): 9a02869

Create models.py

Files changed (1)
  1. models.py +252 -22
models.py CHANGED
@@ -1,24 +1,254 @@
- from dataclasses import dataclass
- from typing import List
- from enum import Enum
-
- @dataclass
- class MessageAnalysis:
-     timestamp: str
-     message_id: str
-     text: str
-     sender: str
-     abuse_score: float
-     darvo_score: float
-     boundary_health: str
-     detected_patterns: List[str]
-     emotional_tone: str
-     risk_level: str
-
- class RiskTrend(Enum):
-     ESCALATING = "escalating"
-     IMPROVING = "improving"
-     STABLE_HIGH = "stable_high"
-     STABLE_MODERATE = "stable_moderate"
-     CYCLICAL = "cyclical"
-     UNKNOWN = "unknown"
+ import torch
+ import logging
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from torch.nn.functional import softmax
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class ModelManager:
+     def __init__(self, device=None):
+         """Initialize model manager with device detection"""
+         self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         logger.info(f"Using device: {self.device}")
+
+         # Initialize model containers
+         self.models = {}
+         self.tokenizers = {}
+         self.emotion_pipeline = None  # set in load_models(); None keeps get_emotion_profile safe before loading
+
+     def load_models(self):
+         """Load all required models"""
+         # Core abuse pattern detection model
+         self._load_model(
+             "abuse_patterns",
+             "SamanthaStorm/tether-multilabel-v6",
+             is_multilabel=True
+         )
+
+         # Sentiment model
+         self._load_model(
+             "sentiment",
+             "SamanthaStorm/tether-sentiment-v3",
+             is_multilabel=False
+         )
+
+         # DARVO model
+         self._load_model(
+             "darvo",
+             "SamanthaStorm/tether-darvo-regressor-v1",
+             is_multilabel=False,
+             is_regression=True
+         )
+
+         # Boundary health model
+         self._load_model(
+             "boundary",
+             "SamanthaStorm/healthy-boundary-predictor",
+             is_multilabel=False
+         )
+
+         # Intent analyzer model
+         self._load_model(
+             "intent",
+             "SamanthaStorm/intentanalyzer",
+             is_multilabel=False
+         )
+
+         # Emotion model
+         try:
+             from transformers import pipeline
+             self.emotion_pipeline = pipeline(
+                 "text-classification",
+                 model="j-hartmann/emotion-english-distilroberta-base",
+                 top_k=None,  # return scores for all emotion labels
+                 truncation=True,
+                 device=0 if torch.cuda.is_available() else -1
+             )
+             logger.info("Emotion pipeline loaded successfully")
+         except Exception as e:
+             logger.error(f"Error loading emotion pipeline: {e}")
+             self.emotion_pipeline = None
+
+         logger.info("All models loaded successfully")
+
+     def _load_model(self, name, model_path, is_multilabel=False, is_regression=False):
+         """Helper to load a model and its tokenizer"""
+         try:
+             logger.info(f"Loading {name} model from {model_path}")
+             self.models[name] = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
+             self.tokenizers[name] = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+             # Store model metadata
+             self.models[name].is_multilabel = is_multilabel
+             self.models[name].is_regression = is_regression
+
+             logger.info(f"{name} model loaded successfully")
+         except Exception as e:
+             logger.error(f"Error loading {name} model: {e}")
+             raise
+
+     def predict_abuse_patterns(self, text, thresholds):
+         """Predict abuse patterns with per-label thresholds"""
+         if not text.strip():
+             return [], []
+
+         inputs = self._prepare_inputs("abuse_patterns", text)
+
+         with torch.no_grad():
+             outputs = self.models["abuse_patterns"](**inputs)
+
+         # Get sigmoid scores for multi-label classification
+         raw_scores = torch.sigmoid(outputs.logits.squeeze(0)).cpu().numpy()
+
+         # Get labels
+         labels = self.get_abuse_pattern_labels()
+
+         # Apply thresholds and return
+         predictions = list(zip(labels, raw_scores))
+         matched_scores = []
+         threshold_labels = []
+
+         for label, score in predictions:
+             if score > thresholds.get(label, 0.25):
+                 threshold_labels.append(label)
+                 weight = self.get_pattern_weight(label)
+                 matched_scores.append((label, float(score), weight))
+
+         return threshold_labels, matched_scores
+
+     def predict_sentiment(self, text):
+         """Predict sentiment (supportive vs. undermining)"""
+         if not text.strip():
+             return "neutral", 0.5
+
+         inputs = self._prepare_inputs("sentiment", text)
+
+         with torch.no_grad():
+             outputs = self.models["sentiment"](**inputs)
+             logits = outputs.logits[0]
+             probs = softmax(logits, dim=-1).cpu().numpy()
+
+         # Get sentiment labels
+         labels = ["supportive", "undermining"]
+         sentiment = labels[int(probs.argmax())]
+         confidence = float(probs.max())
+
+         return sentiment, confidence
+
+     def predict_darvo(self, text):
+         """Predict DARVO score in [0, 1]"""
+         if not text.strip():
+             return 0.0
+
+         inputs = self._prepare_inputs("darvo", text)
+
+         with torch.no_grad():
+             logits = self.models["darvo"](**inputs).logits
+             # Sigmoid squashes the regression output into [0, 1]
+             score = float(torch.sigmoid(logits.cpu()).item())
+
+         return score
+
+     def predict_boundary_health(self, text):
+         """Predict boundary health (1 = healthy, 0 = unhealthy)"""
+         if not text.strip():
+             return 0
+
+         inputs = self._prepare_inputs("boundary", text)
+
+         with torch.no_grad():
+             outputs = self.models["boundary"](**inputs)
+             predictions = softmax(outputs.logits, dim=-1)
+             predicted_class = torch.argmax(predictions, dim=-1).item()
+
+         return predicted_class
+
+     def predict_intent(self, text):
+         """Predict intent"""
+         if not text.strip():
+             return "neutral", 0.5
+
+         inputs = self._prepare_inputs("intent", text)
+
+         with torch.no_grad():
+             outputs = self.models["intent"](**inputs)
+             probs = softmax(outputs.logits, dim=-1).cpu().numpy()[0]
+
+         # Intent labels (adjust based on actual model outputs)
+         labels = ["neutral", "manipulative", "supportive", "controlling"]
+         intent = labels[int(probs.argmax())]
+         confidence = float(probs.max())
+
+         return intent, confidence
+
+     def get_emotion_profile(self, text):
+         """Get emotion profile from text"""
+         default_profile = {
+             "sadness": 0.0,
+             "joy": 0.0,
+             "neutral": 0.0,
+             "disgust": 0.0,
+             "anger": 0.0,
+             "fear": 0.0
+         }
+         if not text.strip() or not self.emotion_pipeline:
+             return default_profile
+
+         try:
+             emotions = self.emotion_pipeline(text)
+             # With top_k=None the pipeline may return [[{label, score}, ...]] for a single string
+             if isinstance(emotions, list) and emotions and isinstance(emotions[0], list):
+                 emotions = emotions[0]
+             if isinstance(emotions, list) and emotions and isinstance(emotions[0], dict):
+                 return {e['label'].lower(): round(e['score'], 3) for e in emotions}
+             return default_profile
+         except Exception as e:
+             logger.error(f"Error in get_emotion_profile: {e}")
+             return default_profile
+
+     def _prepare_inputs(self, model_name, text):
+         """Prepare inputs for the model"""
+         inputs = self.tokenizers[model_name](
+             text,
+             return_tensors="pt",
+             truncation=True,
+             padding=True
+         )
+         return {k: v.to(self.device) for k, v in inputs.items()}
+
+     def get_abuse_pattern_labels(self):
+         """Get abuse pattern labels"""
+         return [
+             "recovery phase", "control", "gaslighting", "guilt tripping", "dismissiveness",
+             "blame shifting", "nonabusive", "projection", "insults",
+             "contradictory statements", "obscure language",
+             "veiled threats", "stalking language", "false concern",
+             "false equivalence", "future faking"
+         ]
+
+     def get_pattern_weight(self, label):
+         """Get pattern weight for scoring"""
+         weights = {
+             "recovery phase": 0.7,
+             "control": 1.4,
+             "gaslighting": 1.3,
+             "guilt tripping": 1.2,
+             "dismissiveness": 0.9,
+             "blame shifting": 1.0,
+             "projection": 0.5,
+             "insults": 1.4,
+             "contradictory statements": 1.0,
+             "obscure language": 0.9,
+             "nonabusive": 0.0,
+             "veiled threats": 1.6,
+             "stalking language": 1.8,
+             "false concern": 1.1,
+             "false equivalence": 1.3,
+             "future faking": 0.8
+         }
+         return weights.get(label, 1.0)