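"""Streamlit front end for TruthCheck, a hybrid BERT-BiLSTM fake news detector.

The app loads a trained HybridFakeNewsDetector checkpoint, lets the user paste a
news article, and visualizes the prediction confidence and attention weights.

Run it with Streamlit (the entry-point path below is an assumption; substitute
this file's actual location in the repository):

    streamlit run app/streamlit_app.py
"""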
import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk

# Download required NLTK data
nltk_data = {
    'tokenizers/punkt': 'punkt',
    'corpora/stopwords': 'stopwords',
    'tokenizers/punkt_tab': 'punkt_tab',
    'corpora/wordnet': 'wordnet'
}
for resource, package in nltk_data.items():
    try:
        nltk.data.find(resource)
    except LookupError:
        nltk.download(package)

# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import (
    BERT_MODEL_NAME,
    LSTM_HIDDEN_SIZE,
    LSTM_NUM_LAYERS,
    DROPOUT_RATE,
    SAVED_MODELS_DIR,
    MAX_SEQUENCE_LENGTH,
)
from src.data.preprocessor import TextPreprocessor
# Custom CSS with Poppins font
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Poppins:wght@200;300;400;500;600;700&display=swap');

* {
    font-family: 'Poppins', sans-serif !important;
    box-sizing: border-box;
}
.stApp {
    background: #ffffff;
    min-height: 100vh;
    color: #1f2a44;
}
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display: none;}
header {visibility: hidden;}
.stApp > header {visibility: hidden;}

/* Main Container */
.main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 1rem 2rem;
}

/* Header Section */
.header-section {
    text-align: center;
    margin-bottom: 2.5rem;
    padding: 1.5rem 0;
}
.header-title {
    font-size: 2.25rem;
    font-weight: 700;
    color: #1f2a44;
    margin: 0;
}

/* Hero Section */
.hero {
    display: flex;
    align-items: center;
    gap: 2rem;
    margin-bottom: 2rem;
    padding: 0 1rem;
}
.hero-left {
    flex: 1;
    padding: 1.5rem;
}
.hero-right {
    flex: 1;
    display: flex;
    align-items: center;
    justify-content: center;
}
.hero-right img {
    max-width: 100%;
    height: auto;
    border-radius: 8px;
    object-fit: cover;
}
.hero-title {
    font-size: 2.5rem;
    font-weight: 700;
    color: #1f2a44;
    margin-bottom: 0.5rem;
}
.hero-text {
    font-size: 1rem;
    color: #6b7280;
    line-height: 1.6;
    max-width: 450px;
}

/* About Section */
.about-section {
    margin-bottom: 2rem;
    text-align: center;
    padding: 0 1rem;
}
.about-title {
    font-size: 1.75rem;
    font-weight: 600;
    color: #1f2a44;
    margin-bottom: 0.5rem;
}
.about-text {
    font-size: 0.95rem;
    color: #6b7280;
    line-height: 1.6;
    max-width: 600px;
    margin: 0 auto;
}

/* Input Section */
.input-container {
    max-width: 800px;
    margin: 0 auto;
}
.stTextArea > div > div > textarea {
    border-radius: 8px !important;
    border: 1px solid #d1d5db !important;
    padding: 1rem !important;
    font-size: 1rem !important;
    background: #ffffff !important;
    min-height: 150px !important;
    transition: all 0.2s ease !important;
}
.stTextArea > div > div > textarea:focus {
    border-color: #6366f1 !important;
    box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1) !important;
    outline: none !important;
}
.stTextArea > div > div > textarea::placeholder {
    color: #9ca3af !important;
}

/* Button Styling */
.stButton > button {
    background: #6366f1 !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 0.75rem 2rem !important;
    font-size: 1rem !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
    border: none !important;
    width: 100% !important;
    max-width: 300px;
}
.stButton > button:hover {
    background: #4f46e5 !important;
    transform: translateY(-1px) !important;
}

/* Results Section */
.results-container {
    margin-top: 1rem;
    padding: 1rem;
    border-radius: 8px;
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
}
.result-card {
    padding: 1rem;
    border-radius: 8px;
    border-left: 4px solid transparent;
    margin-bottom: 1rem;
}
.fake-news {
    background: #fef2f2;
    border-left-color: #ef4444;
}
.real-news {
    background: #ecfdf5;
    border-left-color: #10b981;
}
.prediction-badge {
    font-weight: 600;
    font-size: 1rem;
    margin-bottom: 0.5rem;
    display: flex;
    align-items: center;
    gap: 0.5rem;
}
.confidence-score {
    font-weight: 600;
    margin-left: auto;
    font-size: 1rem;
}

/* Chart Containers */
.chart-container {
    padding: 1rem;
    border-radius: 8px;
    margin: 1rem 0;
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
}

/* Footer */
.footer {
    border-top: 1px solid #e5e7eb;
    padding: 1.5rem 0;
    text-align: center;
    max-width: 1200px;
    margin: 2rem auto 0;
}

/* Responsive Design */
@media (max-width: 1024px) {
    .hero {
        flex-direction: column;
        text-align: center;
    }
    .hero-right img {
        max-width: 80%;
    }
}
@media (max-width: 768px) {
    .header-title {
        font-size: 1.75rem;
    }
    .hero-title {
        font-size: 2rem;
    }
    .hero-text {
        font-size: 0.9rem;
    }
    .about-title {
        font-size: 1.5rem;
    }
    .about-text {
        font-size: 0.9rem;
    }
}
@media (max-width: 480px) {
    .header-title {
        font-size: 1.5rem;
    }
    .hero-title {
        font-size: 1.75rem;
    }
    .hero-text {
        font-size: 0.85rem;
    }
    .about-title {
        font-size: 1.25rem;
    }
    .about-text {
        font-size: 0.85rem;
    }
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_model_and_tokenizer() -> tuple[HybridFakeNewsDetector, BertTokenizer] | tuple[None, None]:
    """Load the model and tokenizer (cached)."""
    try:
        model = HybridFakeNewsDetector(
            bert_model_name=BERT_MODEL_NAME,
            lstm_hidden_size=LSTM_HIDDEN_SIZE,
            lstm_num_layers=LSTM_NUM_LAYERS,
            dropout_rate=DROPOUT_RATE
        )
        model_path = SAVED_MODELS_DIR / "final_model.pt"
        if not model_path.exists():
            st.error("Model file not found. Please ensure 'final_model.pt' is in the models/saved directory.")
            return None, None
        # Load weights on CPU and drop any keys that do not match the current architecture.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model_state_dict = model.state_dict()
        filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
        model.load_state_dict(filtered_state_dict, strict=False)
        model.eval()
        tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {str(e)}")
        return None, None
@st.cache_resource
def get_preprocessor() -> TextPreprocessor | None:
    """Get the text preprocessor (cached)."""
    try:
        return TextPreprocessor()
    except Exception as e:
        st.error(f"Error initializing preprocessor: {str(e)}")
        return None
def predict_news(text: str) -> dict | None:
    """Predict if the given news is fake or real."""
    model, tokenizer = load_model_and_tokenizer()
    if model is None or tokenizer is None:
        return None
    preprocessor = get_preprocessor()
    if preprocessor is None:
        return None
    try:
        # Clean the raw text, then encode it for BERT with padding/truncation.
        processed_text = preprocessor.preprocess_text(text)
        encoding = tokenizer.encode_plus(
            processed_text,
            add_special_tokens=True,
            max_length=MAX_SEQUENCE_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(
                encoding['input_ids'],
                encoding['attention_mask']
            )
        probabilities = torch.softmax(outputs['logits'], dim=1)
        prediction = torch.argmax(outputs['logits'], dim=1)
        attention_weights = outputs.get('attention_weights', torch.zeros(1))
        attention_weights_np = attention_weights[0].cpu().numpy()
        return {
            'prediction': prediction.item(),
            'label': 'FAKE' if prediction.item() == 1 else 'REAL',
            'confidence': torch.max(probabilities, dim=1)[0].item(),
            'probabilities': {
                'REAL': probabilities[0][0].item(),
                'FAKE': probabilities[0][1].item()
            },
            'attention_weights': attention_weights_np
        }
    except Exception as e:
        st.error(f"Prediction error: {str(e)}")
        return None
def plot_confidence(probabilities: dict) -> go.Figure:
    """Plot prediction confidence with simplified styling."""
    if not probabilities or not isinstance(probabilities, dict):
        return go.Figure()
    fig = go.Figure(data=[
        go.Bar(
            x=list(probabilities.keys()),
            y=list(probabilities.values()),
            text=[f'{p:.1%}' for p in probabilities.values()],
            textposition='auto',
            marker=dict(
                color=['#10b981', '#ef4444'],
                line=dict(color='#ffffff', width=1),
            ),
        )
    ])
    fig.update_layout(
        title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(title='Classification', title_font={'size': 12}, tickfont={'size': 10}),
        yaxis=dict(title='Probability', range=[0, 1], tickformat='.0%', title_font={'size': 12}, tickfont={'size': 10}),
        template='plotly_white',
        height=300,
        margin=dict(t=60, b=60)
    )
    return fig
def plot_attention(text: str, attention_weights: np.ndarray) -> go.Figure:
    """Plot attention weights with simplified styling."""
    if not text or not np.asarray(attention_weights).size:
        return go.Figure()
    # Flatten the weights first, then align them with the first 20 words of the text.
    attention_weights = np.asarray(attention_weights).flatten()
    tokens = text.split()[:20]
    attention_weights = attention_weights[:len(tokens)]
    tokens = tokens[:len(attention_weights)]
    max_weight = attention_weights.max() if attention_weights.size else 0.0
    normalized_weights = attention_weights / max_weight if max_weight > 0 else attention_weights
    colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in normalized_weights]
    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=attention_weights,
            text=[f'{float(w):.3f}' for w in attention_weights],
            textposition='auto',
            marker=dict(color=colors),
        )
    ])
    fig.update_layout(
        title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(title='Words', tickangle=45, title_font={'size': 12}, tickfont={'size': 10}),
        yaxis=dict(title='Attention Score', title_font={'size': 12}, tickfont={'size': 10}),
        template='plotly_white',
        height=350,
        margin=dict(t=60, b=80)
    )
    return fig
def main():
    # Main Container
    st.markdown('<div class="main-container">', unsafe_allow_html=True)

    # Header Section
    st.markdown("""
    <div class="header-section">
        <h1 class="header-title">🛡️ TruthCheck - Advanced Fake News Detector</h1>
    </div>
    """, unsafe_allow_html=True)

    # Hero Section
    st.markdown("""
    <div class="hero">
        <div class="hero-left">
            <h2 class="hero-title">Instant Fake News Detection</h2>
            <p class="hero-text">
                Verify news articles with our AI-powered tool, driven by advanced BERT and BiLSTM models for accurate authenticity analysis.
            </p>
        </div>
        <div class="hero-right">
            <img src="https://images.pexels.com/photos/267350/pexels-photo-267350.jpeg?auto=compress&cs=tinysrgb&w=500" alt="Fake News Illustration" onerror="this.src='https://via.placeholder.com/500x300.png?text=Fake+News+Illustration'">
        </div>
    </div>
    """, unsafe_allow_html=True)

    # About Section
    st.markdown("""
    <div class="about-section">
        <h2 class="about-title">About TruthCheck</h2>
        <p class="about-text">
            TruthCheck harnesses a hybrid BERT-BiLSTM model to detect fake news with high precision. Simply paste an article below to analyze its authenticity instantly.
        </p>
    </div>
    """, unsafe_allow_html=True)

    # Input Section
    st.markdown('<div class="input-container">', unsafe_allow_html=True)
    news_text = st.text_area(
        "Analyze a News Article",
        height=150,
        placeholder="Paste your news article here for instant AI analysis...",
        key="news_input"
    )
    st.markdown('</div>', unsafe_allow_html=True)

    # Analyze Button
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        analyze_button = st.button("🔍 Analyze Now", key="analyze_button")
    if analyze_button:
        # Require at least 10 words, matching the error message shown below.
        if news_text and len(news_text.strip().split()) >= 10:
            with st.spinner("Analyzing article..."):
                result = predict_news(news_text)
            if result:
                st.markdown('<div class="results-container">', unsafe_allow_html=True)
                # Prediction Result
                col1, col2 = st.columns([1, 1], gap="medium")
                with col1:
                    if result['label'] == 'FAKE':
                        st.markdown(f'''
                        <div class="result-card fake-news">
                            <div class="prediction-badge">🚨 Fake News Detected <span class="confidence-score">{result["confidence"]:.1%}</span></div>
                            <p>Our AI has identified this content as likely misinformation based on linguistic patterns and context.</p>
                        </div>
                        ''', unsafe_allow_html=True)
                    else:
                        st.markdown(f'''
                        <div class="result-card real-news">
                            <div class="prediction-badge">✅ Authentic News <span class="confidence-score">{result["confidence"]:.1%}</span></div>
                            <p>This content appears legitimate based on professional writing style and factual consistency.</p>
                        </div>
                        ''', unsafe_allow_html=True)
                with col2:
                    st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                    st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
                    st.markdown('</div>', unsafe_allow_html=True)
                # Attention Analysis
                st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
                st.markdown('</div></div>', unsafe_allow_html=True)
        else:
            st.error("Please enter a news article (at least 10 words) for analysis.")
    # Footer
    st.markdown("---")
    st.markdown(
        '<p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❤️ using Streamlit | © 2025</p>',
        unsafe_allow_html=True
    )
    st.markdown('</div>', unsafe_allow_html=True)  # Close main-container


if __name__ == "__main__":
    main()