# NOTE(review): removed non-Python file-viewer residue (HF Spaces header,
# byte count, and line-number ruler) that preceded the module and broke parsing.
import logging
from typing import Dict, Any, List
from transformers import pipeline
from transformers import AutoTokenizer
import numpy as np
logger = logging.getLogger(__name__)
class HeadlineAnalyzer:
    """Score how well a news headline is supported by the article body.

    Uses the ``roberta-large-mnli`` NLI model: each headline/content pair is
    classified into ENTAILMENT / NEUTRAL / CONTRADICTION probabilities, which
    are combined into a single 0-100 consistency score.  Content that (with
    the headline) exceeds the model's 512-token limit is split into
    overlapping sections that are scored individually and aggregated.
    """

    def __init__(self):
        """Initialize the NLI model for contradiction detection."""
        self.nli_pipeline = pipeline("text-classification", model="roberta-large-mnli")
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
        # Hard input limit of the roberta-large-mnli encoder.
        self.max_length = 512

    def _split_content(self, headline: str, content: str) -> List[str]:
        """Split *content* into sections that fit within the token limit.

        Each section, combined with the headline and the "[SEP]" marker,
        stays under ``self.max_length`` tokens.  Consecutive sections share
        roughly 20% of their words so context is not lost at boundaries.
        """
        content_words = content.split()
        sections: List[str] = []
        current_section: List[str] = []

        # Budget left for content once the headline and "[SEP]" are counted.
        # encode() adds the two special tokens, hence the -2 correction.
        headline_tokens = len(self.tokenizer.encode(headline))
        sep_tokens = len(self.tokenizer.encode("[SEP]")) - 2
        max_content_tokens = self.max_length - headline_tokens - sep_tokens

        for word in content_words:
            current_section.append(word)
            # Re-encode to check whether this word crossed the limit.
            current_text = " ".join(current_section)
            if len(self.tokenizer.encode(current_text)) >= max_content_tokens:
                # The word just added may be what pushed us over; defer it
                # to the next section.
                current_section.pop()
                # Fix: only flush a non-empty section.  Previously a single
                # word longer than the budget emitted empty "" sections.
                if current_section:
                    sections.append(" ".join(current_section))
                    # Keep the trailing ~20% of words as overlap for context.
                    overlap_start = max(0, len(current_section) - int(len(current_section) * 0.2))
                    current_section = current_section[overlap_start:]
                current_section.append(word)

        # Flush whatever remains as the final section.
        if current_section:
            sections.append(" ".join(current_section))

        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info(
            "Content Splitting:\n"
            "   - Original content length: %d words\n"
            "   - Split into %d sections\n"
            "   - Headline uses %d tokens\n"
            "   - Available tokens per section: %d",
            len(content_words), len(sections), headline_tokens, max_content_tokens,
        )
        return sections

    def _analyze_section(self, headline: str, section: str) -> Dict[str, float]:
        """Run the NLI model on one headline/section pair.

        Returns a mapping of model label ('ENTAILMENT', 'NEUTRAL',
        'CONTRADICTION') to its probability.
        """
        input_text = f"{headline} [SEP] {section}"
        # truncation=True guards against the token estimate in
        # _split_content being slightly off; without it an over-length
        # input raises instead of being clipped to the model limit.
        result = self.nli_pipeline(input_text, top_k=None, truncation=True)
        scores = {item['label']: item['score'] for item in result}

        logger.info("\nSection Analysis:")
        logger.info("-" * 30)
        logger.info("Section preview: %s...", section[:100])
        for label, score in scores.items():
            logger.info("Label: %-12s Score: %.3f", label, score)
        return scores

    def analyze(self, headline: str, content: str) -> Dict[str, Any]:
        """Analyze how well the headline matches the content using an AI model.

        Returns a dict with ``headline_vs_content_score`` (0-100),
        ``entailment_score`` / ``contradiction_score`` (0-1), and
        ``contradictory_phrases`` (currently always empty).  On any error a
        zeroed result dict is returned rather than raising.
        """
        try:
            logger.info("\n" + "=" * 50)
            logger.info("HEADLINE ANALYSIS STARTED")
            logger.info("=" * 50)

            # Empty input cannot be scored meaningfully.
            if not headline.strip() or not content.strip():
                logger.warning("Empty headline or content provided")
                return {
                    "headline_vs_content_score": 0,
                    "entailment_score": 0,
                    "contradiction_score": 0,
                    "contradictory_phrases": []
                }

            content_tokens = len(self.tokenizer.encode(content))
            headline_tokens = len(self.tokenizer.encode(headline))
            # Bug fix: the headline shares the 512-token budget with the
            # content, so both must be counted when deciding to split.
            # Checking the content alone let near-limit pairs overflow.
            if content_tokens + headline_tokens > self.max_length:
                logger.warning(
                    "\nContent Length Warning:\n"
                    "   - Content tokens: %d (+ %d headline tokens)\n"
                    "   - Max allowed: %d\n"
                    "   - Splitting into sections...",
                    content_tokens, headline_tokens, self.max_length,
                )
                sections = self._split_content(headline, content)

                # Analyze each section independently.
                section_scores = []
                for i, section in enumerate(sections, 1):
                    logger.info("\nAnalyzing section %d/%d", i, len(sections))
                    section_scores.append(self._analyze_section(headline, section))

                # Aggregate across sections:
                #   - mean entailment (overall support),
                #   - max contradiction (one strongly contradicting section matters),
                #   - mean neutral (general neutral tone).
                # float() keeps the results plain (JSON-serializable) floats
                # instead of numpy scalars.
                entailment_score = float(np.mean([s.get('ENTAILMENT', 0) for s in section_scores]))
                contradiction_score = float(np.max([s.get('CONTRADICTION', 0) for s in section_scores]))
                neutral_score = float(np.mean([s.get('NEUTRAL', 0) for s in section_scores]))

                logger.info("\nAggregated Scores Across Sections:")
                logger.info("-" * 30)
                logger.info("Mean Entailment: %.3f", entailment_score)
                logger.info("Max Contradiction: %.3f", contradiction_score)
                logger.info("Mean Neutral: %.3f", neutral_score)
            else:
                # Single-section analysis: the pair fits in one pass.
                scores = self._analyze_section(headline, content)
                entailment_score = scores.get('ENTAILMENT', 0)
                contradiction_score = scores.get('CONTRADICTION', 0)
                neutral_score = scores.get('NEUTRAL', 0)

            # NOTE(review): with probabilities summing to 1, the maximum
            # reachable score is 70% (pure entailment: 0.6 + 0 + 0.1) --
            # confirm this weighting is intentional before rescaling.
            final_score = (
                (entailment_score * 0.6) +          # Base score from entailment
                (neutral_score * 0.3) +             # Neutral is acceptable
                ((1 - contradiction_score) * 0.1)   # Small penalty for contradiction
            ) * 100

            logger.info("\nFinal Analysis Results:")
            logger.info("-" * 30)
            logger.info("Headline: %s", headline)
            logger.info("Content Length: %d tokens", content_tokens)
            logger.info("\nFinal Scores:")
            logger.info("%-15s %.3f", "Entailment:", entailment_score)
            logger.info("%-15s %.3f", "Neutral:", neutral_score)
            logger.info("%-15s %.3f", "Contradiction:", contradiction_score)
            logger.info("\nFinal Score: %.1f%%", final_score)
            logger.info("=" * 50 + "\n")

            return {
                "headline_vs_content_score": round(final_score, 1),
                "entailment_score": round(entailment_score, 2),
                "contradiction_score": round(contradiction_score, 2),
                "contradictory_phrases": []
            }

        except Exception as e:
            logger.error("\nHEADLINE ANALYSIS ERROR")
            logger.error("-" * 30)
            logger.error("Error Type: %s", type(e).__name__)
            logger.error("Error Message: %s", e)
            logger.error("Stack Trace:", exc_info=True)
            logger.error("=" * 50 + "\n")
            return {
                "headline_vs_content_score": 0,
                "entailment_score": 0,
                "contradiction_score": 0,
                "contradictory_phrases": []
            }