# TruthCheck / src/app.py
# Author: KhaqanNasir; last change: commit d49dd8a ("Update src/app.py").
import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.express as px
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk
# Make sure every NLTK data package this app expects is available locally:
# each entry pairs the resource path nltk.data.find() looks up with the
# package name nltk.download() fetches when that lookup fails.
for _resource_path, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)
del _resource_path, _package

# This file lives in <root>/src/, so its grandparent directory is the project
# root; appending it to sys.path lets the `src.*` imports below resolve.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor
# Custom CSS for enhanced, modern styling.
# Injected once as a raw <style> block (unsafe_allow_html=True). It styles both
# Streamlit's own widgets (.stApp, .stTextArea, .stButton) and the custom HTML
# sections that main() emits (.header, .hero, .about-section, .result-card,
# .chart-container, .footer). It also hides Streamlit's main menu, footer,
# header and deploy button.
# NOTE(review): the universal `*` rule zeroes margin/padding on every element,
# including Streamlit widget internals — confirm this is intended.
st.markdown("""
<style>
/* Import Google Fonts */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
/* Global Styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.stApp {
font-family: 'Inter', sans-serif;
background: #ffffff;
min-height: 100vh;
color: #1a202c;
}
/* Hide Streamlit elements */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stDeployButton {display: none;}
header {visibility: hidden;}
.stApp > header {visibility: hidden;}
/* Container */
.container {
max-width: 1280px;
margin: 0 auto;
padding: 2rem 1.5rem;
}
/* Header */
.header {
background: #ffffff;
padding: 1rem 2rem;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.08);
position: sticky;
top: 0;
z-index: 1000;
}
.header-title {
font-size: 2rem;
font-weight: 800;
color: #1a202c;
display: flex;
align-items: center;
gap: 0.5rem;
}
/* Hero Section */
.hero {
display: flex;
align-items: center;
gap: 3rem;
margin-bottom: 4rem;
background: linear-gradient(135deg, #f8fafc 0%, #edf2f7 100%);
padding: 4rem 2rem;
border-radius: 16px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
}
.hero-left {
flex: 1;
padding: 1rem;
}
.hero-right {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.hero-right img {
max-width: 100%;
height: auto;
border-radius: 12px;
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.1);
transition: transform 0.3s ease;
}
.hero-right img:hover {
transform: scale(1.02);
}
.hero-title {
font-size: 3rem;
font-weight: 800;
color: #1a202c;
margin-bottom: 1.5rem;
line-height: 1.2;
}
.hero-text {
font-size: 1.2rem;
color: #4a5568;
line-height: 1.7;
max-width: 500px;
}
/* About Section */
.about-section {
margin-bottom: 3rem;
text-align: center;
padding: 2rem;
}
.about-title {
font-size: 2.2rem;
font-weight: 700;
color: #1a202c;
margin-bottom: 1rem;
}
.about-text {
font-size: 1.1rem;
color: #4a5568;
line-height: 1.6;
max-width: 700px;
margin: 0 auto;
}
/* Input Section */
.input-container {
max-width: 900px;
margin: 0 auto;
padding: 1.5rem;
background: #ffffff;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
}
.stTextArea > div > div > textarea {
border-radius: 10px !important;
border: 1px solid #d1d5db !important;
padding: 1.2rem !important;
font-size: 1rem !important;
font-family: 'Inter', sans-serif !important;
background: #f9fafb !important;
min-height: 180px !important;
transition: all 0.3s ease !important;
}
.stTextArea > div > div > textarea:focus {
border-color: #6366f1 !important;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1) !important;
outline: none !important;
}
.stTextArea > div > div > textarea::placeholder {
color: #9ca3af !important;
}
/* Button Styling */
.stButton > button {
background: linear-gradient(135deg, #6366f1 0%, #4f46e5 100%) !important;
color: white !important;
border-radius: 10px !important;
padding: 0.8rem 2.5rem !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
font-family: 'Inter', sans-serif !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 12px rgba(99, 102, 241, 0.3) !important;
width: 100% !important;
border: none !important;
}
.stButton > button:hover {
background: linear-gradient(135deg, #4f46e5 0%, #4338ca 100%) !important;
transform: translateY(-2px) !important;
box-shadow: 0 6px 16px rgba(99, 102, 241, 0.4) !important;
}
/* Results Section */
.results-container {
margin-top: 2rem;
padding: 2rem;
background: #ffffff;
border-radius: 12px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.05);
}
.result-card {
padding: 1.5rem;
border-radius: 10px;
border-left: 5px solid transparent;
margin-bottom: 1rem;
}
.fake-news {
background: #fef2f2;
border-left-color: #ef4444;
}
.real-news {
background: #ecfdf5;
border-left-color: #10b981;
}
.prediction-badge {
font-weight: 600;
font-size: 1.1rem;
margin-bottom: 1rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.confidence-score {
font-weight: 600;
margin-left: auto;
font-size: 1.1rem;
}
/* Chart Containers */
.chart-container {
padding: 1.5rem;
border-radius: 10px;
background: #ffffff;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
margin: 1.5rem 0;
}
/* Footer */
.footer {
margin-top: 4rem;
padding: 1.5rem;
text-align: center;
border-top: 1px solid #e5e7eb;
background: #f8fafc;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_model_and_tokenizer():
    """Build the hybrid detector, load its saved weights, and load the tokenizer.

    Decorated with ``st.cache_resource``, so Streamlit builds the model and
    tokenizer once and hands the same objects back on later calls instead of
    reloading the checkpoint on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)``. ``model`` is a ``HybridFakeNewsDetector``
        with the checkpoint weights loaded onto CPU, in eval mode.
        ``tokenizer`` is the ``BertTokenizer`` for ``BERT_MODEL_NAME``.

    Warns:
        UserWarning: when checkpoint entries were discarded because the model
        has no entry of that name, or when model entries were absent from the
        checkpoint and therefore keep their freshly initialised values. Both
        cases used to pass silently, which lets a mismatched checkpoint produce
        confident-looking but meaningless predictions.
    """
    import warnings

    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
    model_state_dict = model.state_dict()
    # Keep only the entries the model knows about so extra checkpoint keys do
    # not make load_state_dict fail. Record what was dropped so it can be reported.
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    dropped_keys = [k for k in state_dict if k not in model_state_dict]
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if dropped_keys:
        warnings.warn(
            f"Ignored {len(dropped_keys)} checkpoint entries with no matching "
            f"model entry (e.g. {', '.join(dropped_keys[:5])}).",
            stacklevel=2,
        )
    if load_result.missing_keys:
        warnings.warn(
            f"{len(load_result.missing_keys)} model entries were not found in the "
            f"checkpoint and keep their initial values "
            f"(e.g. {', '.join(load_result.missing_keys[:5])}).",
            stacklevel=2,
        )
    model.eval()
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer
@st.cache_resource
def get_preprocessor():
    """Return the shared ``TextPreprocessor``.

    ``st.cache_resource`` makes Streamlit construct it on the first call and
    return that same instance on every later call.
    """
    preprocessor = TextPreprocessor()
    return preprocessor
def predict_news(text):
    """Classify a news article as fake or real with the cached hybrid model.

    The text is cleaned with ``TextPreprocessor.preprocess_text``. It is then
    tokenised to exactly ``MAX_SEQUENCE_LENGTH`` tokens (padded and truncated)
    and run through the model without gradient tracking.

    Args:
        text: raw article text as entered by the user.

    Returns:
        dict with keys:
            ``'prediction'``: int class index from the logits' argmax
            (1 is reported as FAKE, anything else as REAL).
            ``'label'``: ``'FAKE'`` or ``'REAL'``.
            ``'confidence'``: the largest softmax probability.
            ``'probabilities'``: ``{'REAL': p[0], 'FAKE': p[1]}`` from the softmax.
            ``'attention_weights'``: numpy copy of
            ``outputs['attention_weights'][0]`` for this single sample.
            ``'tokens'``: tokenizer tokens for the non-padding input positions,
            in order (starting with ``[CLS]``), for labelling attention
            weights by position.
    """
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()
    processed_text = preprocessor.preprocess_text(text)
    # Calling the tokenizer directly is the current form of the legacy
    # ``encode_plus`` API; the arguments and returned tensors are the same.
    encoding = tokenizer(
        processed_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(
            encoding['input_ids'],
            encoding['attention_mask'],
        )
    logits = outputs['logits']
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = torch.argmax(logits, dim=1).item()
    # Batch of one, so take sample 0.
    # NOTE(review): assumes the model returns one attention value per input
    # position — confirm against HybridFakeNewsDetector.
    attention_weights_np = outputs['attention_weights'][0].cpu().numpy()
    # Tokens for the real (unpadded) positions, so callers can label the
    # attention weights by position instead of with the raw article words.
    real_length = int(encoding['attention_mask'][0].sum().item())
    tokens = tokenizer.convert_ids_to_tokens(
        encoding['input_ids'][0][:real_length].tolist()
    )
    return {
        'prediction': predicted_class,
        'label': 'FAKE' if predicted_class == 1 else 'REAL',
        'confidence': torch.max(probabilities, dim=1)[0].item(),
        'probabilities': {
            'REAL': probabilities[0][0].item(),
            'FAKE': probabilities[0][1].item(),
        },
        'attention_weights': attention_weights_np,
        'tokens': tokens,
    }
def plot_confidence(probabilities):
    """Build a bar chart of per-class prediction probabilities.

    Args:
        probabilities: mapping of class label to probability, e.g. the
            ``{'REAL': p, 'FAKE': q}`` dict returned by ``predict_news``.

    Returns:
        plotly.graph_objects.Figure: one bar per label, annotated with its
        percentage, on a y-axis fixed to 0-100%.
    """
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    # Colour bars by label rather than by position, so REAL is always green and
    # FAKE always red whatever the dict order. Any other label gets the accent colour.
    label_colors = {'REAL': '#10b981', 'FAKE': '#ef4444'}
    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=values,
            text=[f'{p:.1%}' for p in values],
            textposition='auto',
            marker=dict(
                color=[label_colors.get(label, '#6366f1') for label in labels],
                line=dict(color='#ffffff', width=1),
            ),
        )
    ])
    fig.update_layout(
        title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 20}},
        # Axis-title fonts belong under ``title.font``. The old ``titlefont``
        # shorthand was deprecated and is rejected by Plotly 6.
        xaxis=dict(
            title=dict(text='Classification', font={'size': 14}),
            tickfont={'size': 12},
        ),
        yaxis=dict(
            title=dict(text='Probability', font={'size': 14}),
            range=[0, 1],
            tickformat='.0%',
            tickfont={'size': 12},
        ),
        template='plotly_white',
        height=350,
        margin=dict(t=80, b=80),
    )
    return fig
def plot_attention(text, attention_weights, max_tokens=20):
    """Build a bar chart of attention weights over the leading tokens.

    Args:
        text: labels for successive positions of *attention_weights*. Either a
            string, which is split on whitespace, or an already-tokenised
            sequence.
        attention_weights: per-position weights (list, numpy array, or anything
            ``np.asarray`` accepts). Sliced along the first axis to the number
            of tokens, then flattened.
        max_tokens: maximum number of leading tokens to plot (default 20, the
            previously hard-coded limit).

    Returns:
        plotly.graph_objects.Figure: one bar per token. Opacity scales with the
        weight relative to the largest plotted weight.
    """
    if isinstance(text, str):
        tokens = text.split()
    else:
        tokens = [str(token) for token in text]
    tokens = tokens[:max_tokens]
    # Slice the position axis before flattening, so a (seq_len, 1) array still
    # yields one weight per token.
    weights = np.asarray(attention_weights, dtype=float)[:len(tokens)].ravel()
    # Keep x and y the same length even if fewer weights than tokens arrive.
    count = min(len(tokens), weights.size)
    tokens, weights = tokens[:count], weights[:count]
    peak = weights.max() if count else 0.0
    normalized = weights / peak if peak > 0 else weights
    # Opacity 0.4-1.0 by relative weight. Clipped so unusual (negative or
    # unnormalised) weights still produce a valid rgba alpha.
    alphas = np.clip(0.4 + 0.6 * normalized, 0.0, 1.0)
    colors = [f'rgba(99, 102, 241, {float(alpha)})' for alpha in alphas]
    positions = list(range(count))
    fig = go.Figure(data=[
        go.Bar(
            # Plot by position and label the ticks with tokens. With tokens as
            # categorical x values, repeated words would share one slot and
            # their bars would overlap.
            x=positions,
            y=weights,
            text=[f'{w:.3f}' for w in weights],
            textposition='auto',
            customdata=tokens,
            hovertemplate='%{customdata}: %{y:.3f}<extra></extra>',
            marker=dict(color=colors),
        )
    ])
    fig.update_layout(
        title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 20}},
        xaxis=dict(
            title=dict(text='Words', font={'size': 14}),
            tickmode='array',
            tickvals=positions,
            ticktext=tokens,
            tickangle=45,
            tickfont={'size': 12},
        ),
        yaxis=dict(
            title=dict(text='Attention Score', font={'size': 14}),
            tickfont={'size': 12},
        ),
        # The template was previously a corrupted string that Plotly rejects.
        template='plotly_white',
        height=400,
        margin=dict(t=80, b=100),
    )
    return fig
# Minimum article length, in whitespace-separated words, accepted for analysis.
_MIN_ARTICLE_WORDS = 10


def _render_intro():
    """Render the static page header, hero banner and "About" section."""
    # Header
    st.markdown("""
<div class="header">
<div class="container">
<h1 class="header-title">🛡️ TruthCheck</h1>
</div>
</div>
""", unsafe_allow_html=True)
    # Hero Section
    # NOTE(review): Streamlit does not serve "/hero.png" by default — confirm
    # the image is actually reachable at that URL.
    st.markdown("""
<div class="container">
<div class="hero">
<div class="hero-left">
<h2 class="hero-title">Instant Fake News Detection</h2>
<p class="hero-text">
Discover the truth behind news articles with our cutting-edge AI. Powered by a hybrid BERT-BiLSTM model, TruthCheck delivers fast, accurate, and transparent analysis of news authenticity.
</p>
</div>
<div class="hero-right">
<img src="/hero.png" alt="TruthCheck Illustration">
</div>
</div>
</div>
""", unsafe_allow_html=True)
    # About Section
    st.markdown("""
<div class="container">
<div class="about-section">
<h2 class="about-title">About TruthCheck</h2>
<p class="about-text">
TruthCheck combines advanced BERT and BiLSTM technologies to detect fake news with over 95% accuracy. Paste any news article below to receive a detailed analysis, including confidence scores and attention insights, in seconds.
</p>
</div>
</div>
""", unsafe_allow_html=True)


def _render_results(result, news_text):
    """Render the verdict card, probability chart and attention chart.

    Args:
        result: dict returned by ``predict_news``.
        news_text: the article as entered. It labels the attention chart only
            when *result* has no ``'tokens'`` entry.
    """
    st.markdown('<div class="container"><div class="results-container">', unsafe_allow_html=True)
    # Prediction Result
    verdict_col, chart_col = st.columns([1, 1], gap="medium")
    with verdict_col:
        if result['label'] == 'FAKE':
            st.markdown(f'''
<div class="result-card fake-news">
<div class="prediction-badge">🚨 Fake News Detected <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>Our AI has identified this content as likely misinformation based on linguistic patterns, structural analysis, and content inconsistencies.</p>
</div>
''', unsafe_allow_html=True)
        else:
            st.markdown(f'''
<div class="result-card real-news">
<div class="prediction-badge">✅ Authentic News <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>This content appears to be legitimate based on professional writing style, factual consistency, and structural integrity.</p>
</div>
''', unsafe_allow_html=True)
    with chart_col:
        st.markdown('<div class="chart-container">', unsafe_allow_html=True)
        st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
        st.markdown('</div>', unsafe_allow_html=True)
    # Attention Analysis
    # Label bars with the model's own tokens when available. The raw article
    # words do not line up with the positions the weights refer to ([CLS],
    # word pieces, text changed by preprocessing). Tokens contain no
    # whitespace, so joining with spaces round-trips through plot_attention's split().
    tokens = result.get('tokens')
    attention_labels = ' '.join(tokens) if tokens else news_text
    st.markdown('<div class="chart-container">', unsafe_allow_html=True)
    st.plotly_chart(plot_attention(attention_labels, result['attention_weights']), use_container_width=True)
    st.markdown('</div></div></div>', unsafe_allow_html=True)


def _show_error(message):
    """Show *message* as a Streamlit error, wrapped in the page container markup."""
    st.markdown('<div class="container">', unsafe_allow_html=True)
    st.error(message)
    st.markdown('</div>', unsafe_allow_html=True)


def _render_footer():
    """Render the page footer."""
    st.markdown("""
<div class="footer">
<p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❤️ using Streamlit | © 2025</p>
</div>
""", unsafe_allow_html=True)


def main():
    """Render the TruthCheck page and analyse the entered article on request."""
    _render_intro()
    # NOTE(review): each st.markdown call renders as its own element, so the
    # opening/closing <div> fragments below probably do not wrap the widgets
    # placed between them — confirm the intended layout.
    # Input Section
    st.markdown('<div class="container"><div class="input-container">', unsafe_allow_html=True)
    news_text = st.text_area(
        "Analyze a News Article",
        height=180,
        placeholder="Paste your news article here for instant AI analysis...",
        key="news_input"
    )
    st.markdown('</div>', unsafe_allow_html=True)
    # Analyze Button
    st.markdown('<div class="container">', unsafe_allow_html=True)
    _, button_col, _ = st.columns([1, 2, 1])
    with button_col:
        analyze_button = st.button("🔍 Analyze Now", key="analyze_button")
    st.markdown('</div>', unsafe_allow_html=True)
    if analyze_button:
        # Count words, not characters, so the check matches the
        # "at least N words" message shown when it fails.
        if len((news_text or "").split()) >= _MIN_ARTICLE_WORDS:
            with st.spinner("Analyzing article..."):
                try:
                    result = predict_news(news_text)
                    _render_results(result, news_text)
                except Exception as e:
                    # UI boundary: show any failure to the user instead of a traceback.
                    _show_error(f"Error: {e}. Please try again or contact support.")
        else:
            _show_error(f"Please enter a news article (at least {_MIN_ARTICLE_WORDS} words) for analysis.")
    # Footer
    _render_footer()


if __name__ == "__main__":
    main()