Create models.py
models.py CHANGED
@@ -1,29 +1,4 @@
-
-        """Predict intent using multilabel classification"""
-        if not text.strip():
-            return "neutral", 0.5
-
-        # Check if intent model is available
-        if "intent" not in self.models:
-            logger.warning("Intent model not available, returning neutral intent")
-            return "neutral", 0.5
-
-        try:
-            inputs = self._prepare_inputs("intent", text)
-
-            with torch.no_grad():
-                outputs = self.models["intent"](**inputs)
-
-            if self.models["intent"].is_multilabel:
-                # For multilabel, use sigmoid and get all scores
-                probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]
-                logger.debug(f"Intent model raw probabilities: {probs}")
-
-                # Get intent labels - you may need to adjust these based on your actual model
-                intent_labels = ["neutral", "manipulative", "supportive", "controlling", "threatening", "gaslighting"]
-
-                # Debug: print all scores
-
+import torch
 import torch.nn as nn
 import logging
 import os
@@ -51,10 +26,6 @@ class MultiLabelIntentClassifier(nn.Module):
         logits = self.classifier(pooled_output)
         return logits
 
-# Set up logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
 class ModelManager:
     def __init__(self, device=None):
         """Initialize model manager with device detection"""
@@ -270,7 +241,6 @@ class ModelManager:
         """Load emotion pipeline with retry logic"""
         for attempt in range(max_retries):
             try:
-                from transformers import pipeline
                 logger.info(f"Loading emotion pipeline (attempt {attempt+1}/{max_retries})")
 
                 self.emotion_pipeline = pipeline(
@@ -289,9 +259,8 @@
                 logger.error(f"Error loading emotion pipeline (attempt {attempt+1}): {e}")
                 time.sleep(2)  # Wait before retry
 
-
-
-        return False
+        logger.error("Failed to load emotion pipeline after all retries")
+        return False
 
     def predict_fallacy(self, text):
         """Predict logical fallacy using FallacyFinder model"""
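
The block removed in the first hunk is the stranded body of a multilabel intent scorer (it references self but sat above the imports). A minimal standalone sketch of the same technique, sigmoid over the logits of a sequence-classification head, is shown below; the function name score_intents, the label list, and the empty-text fallback are modeled on the removed lines, not taken from the final Space code.

import torch

# Label set mirrors the removed fragment; it must match the model's classification head.
INTENT_LABELS = ["neutral", "manipulative", "supportive",
                 "controlling", "threatening", "gaslighting"]

def score_intents(model, tokenizer, text, device="cpu"):
    """Return {label: probability} via sigmoid over the logits (multilabel)."""
    if not text.strip():
        return {"neutral": 0.5}  # same empty-input fallback as the removed code
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    return dict(zip(INTENT_LABELS, probs.tolist()))

# Hypothetical usage with a transformers checkpoint (model id is a placeholder):
#   from transformers import AutoModelForSequenceClassification, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("some/intent-model")
#   mdl = AutoModelForSequenceClassification.from_pretrained("some/intent-model")
#   scores = score_intents(mdl, tok, "you never listen to me")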
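The last two hunks touch the emotion-pipeline loader: the inline transformers import is dropped and a final error log is added before giving up. The retry pattern implied by the context lines looks roughly like the free function below; the signature, default retry count, and None return value are assumptions for illustration, not the Space's exact method.

import logging
import time
from transformers import pipeline

logger = logging.getLogger(__name__)

def load_emotion_pipeline(model_id, max_retries=3):
    """Try to build a text-classification pipeline, retrying on failure."""
    for attempt in range(max_retries):
        try:
            logger.info(f"Loading emotion pipeline (attempt {attempt+1}/{max_retries})")
            return pipeline("text-classification", model=model_id)
        except Exception as e:
            logger.error(f"Error loading emotion pipeline (attempt {attempt+1}): {e}")
            time.sleep(2)  # wait before retry
    # Mirrors the lines added in this commit: log once after exhausting retries, then give up.
    logger.error("Failed to load emotion pipeline after all retries")
    return None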