TruthCheck / src /app.py
adnaan05's picture
Update src/app.py (#21)
2ddd46f verified
raw
history blame
15.1 kB
import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.express as px
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk
# Make sure every NLTK resource the app relies on is present, downloading
# whichever ones cannot be located. Each entry pairs the lookup path passed to
# nltk.data.find with the package name passed to nltk.download.
for _resource_path, _package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)
# Add the project root (the grandparent of this file, i.e. the directory that
# contains this file's own directory) to sys.path so the absolute `src.*`
# imports that follow can be resolved.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))
from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor
# Page-wide stylesheet: fonts, layout containers, hero/about sections, input
# widgets, result cards, chart wrappers, footer and sidebar. Held in a constant
# so the injection call below stays a one-liner.
_CUSTOM_CSS = """
<style>
/* Import Google Fonts */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap');
/* Global Styles */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.stApp {
font-family: 'Inter', sans-serif;
background: #f8fafc;
min-height: 100vh;
color: #1a202c;
}
/* Ensure sidebar is visible */
#MainMenu {visibility: visible;}
footer {visibility: hidden;}
.stDeployButton {display: none;}
header {visibility: hidden;}
.stApp > header {visibility: hidden;}
/* Container */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 1rem;
}
/* Header */
.header {
padding: 1rem 0;
text-align: center;
}
.header-title {
font-size: 2rem;
font-weight: 800;
color: #1a202c;
display: inline-flex;
align-items: center;
gap: 0.5rem;
}
/* Hero Section */
.hero {
display: flex;
align-items: center;
gap: 2rem;
margin-bottom: 2rem;
}
.hero-left {
flex: 1;
padding: 1rem;
}
.hero-right {
flex: 1;
display: flex;
align-items: center;
justify-content: center;
}
.hero-right img {
max-width: 100%;
height: auto;
border-radius: 8px;
}
.hero-title {
font-size: 2.5rem;
font-weight: 700;
color: #1a202c;
margin-bottom: 0.5rem;
}
.hero-text {
font-size: 1rem;
color: #4a5568;
line-height: 1.5;
max-width: 450px;
}
/* About Section */
.about-section {
margin-bottom: 2rem;
text-align: center;
}
.about-title {
font-size: 1.8rem;
font-weight: 600;
color: #1a202c;
margin-bottom: 0.5rem;
}
.about-text {
font-size: 1rem;
color: #4a5568;
line-height: 1.5;
max-width: 600px;
margin: 0 auto;
}
/* Input Section */
.input-container {
max-width: 800px;
margin: 0 auto;
}
.stTextArea > div > div > textarea {
border-radius: 8px !important;
border: 1px solid #d1d5db !important;
padding: 1rem !important;
font-size: 1rem !important;
font-family: 'Inter', sans-serif !important;
background: #ffffff !important;
min-height: 150px !important;
transition: all 0.2s ease !important;
}
.stTextArea > div > div > textarea:focus {
border-color: #6366f1 !important;
box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.1) !important;
outline: none !important;
}
.stTextArea > div > div > textarea::placeholder {
color: #9ca3af !important;
}
/* Button Styling */
.stButton > button {
background: #6366f1 !important;
color: white !important;
border-radius: 8px !important;
padding: 0.75rem 2rem !important;
font-size: 1rem !important;
font-weight: 600 !important;
font-family: 'Inter', sans-serif !important;
transition: all 0.2s ease !important;
border: none !important;
width: 100% !important;
}
.stButton > button:hover {
background: #4f46e5 !important;
transform: translateY(-1px) !important;
}
/* Results Section */
.results-container {
margin-top: 1rem;
padding: 1rem;
border-radius: 8px;
}
.result-card {
padding: 1rem;
border-radius: 8px;
border-left: 4px solid transparent;
margin-bottom: 1rem;
}
.fake-news {
background: #fef2f2;
border-left-color: #ef4444;
}
.real-news {
background: #ecfdf5;
border-left-color: #10b981;
}
.prediction-badge {
font-weight: 600;
font-size: 1rem;
margin-bottom: 0.5rem;
display: flex;
align-items: center;
gap: 0.5rem;
}
.confidence-score {
font-weight: 600;
margin-left: auto;
font-size: 1rem;
}
/* Chart Containers */
.chart-container {
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
}
/* Footer */
.footer {
margin-top: 2rem;
padding: 1rem 0;
text-align: center;
border-top: 1px solid #e5e7eb;
}
/* Sidebar Styling */
.stSidebar {
background: #ffffff;
border-right: 1px solid #e5e7eb;
}
.stSidebar .sidebar-content {
padding: 1rem;
}
</style>
"""

# Inject the stylesheet into the page as raw HTML.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def load_model_and_tokenizer():
    """Build the hybrid detector, load its saved weights, and create the tokenizer.

    The result is cached by ``st.cache_resource``. The checkpoint
    ``SAVED_MODELS_DIR / "final_model.pt"`` is loaded onto the CPU, and only
    entries whose keys exist in the freshly constructed model are applied.
    Checkpoint keys that are dropped, and model parameters that receive no
    value from the checkpoint, are logged as warnings so that a partially
    initialised model does not go unnoticed.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` has been put in eval mode
        and ``tokenizer`` is the ``BertTokenizer`` for ``BERT_MODEL_NAME``.
    """
    import logging

    logger = logging.getLogger(__name__)

    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )

    # NOTE(review): torch.load unpickles the file. If the checkpoint is a plain
    # state dict and the installed torch supports it, consider passing
    # weights_only=True here.
    state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
    model_state_dict = model.state_dict()

    # Keep only the checkpoint entries that the current architecture knows about.
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}

    # Surface both directions of mismatch instead of dropping them silently:
    # a model that runs on untrained parameters still produces predictions.
    skipped_keys = [k for k in state_dict if k not in model_state_dict]
    if skipped_keys:
        logger.warning(
            "Ignoring %d checkpoint entries not present in the model: %s",
            len(skipped_keys), skipped_keys,
        )
    missing_keys = [k for k in model_state_dict if k not in filtered_state_dict]
    if missing_keys:
        logger.warning(
            "%d model parameters were not found in the checkpoint and keep "
            "their initial values: %s",
            len(missing_keys), missing_keys,
        )

    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer
@st.cache_resource
def get_preprocessor():
    """Create the shared ``TextPreprocessor`` instance (cached by Streamlit)."""
    preprocessor = TextPreprocessor()
    return preprocessor
def predict_news(text):
    """Classify a piece of news text as FAKE or REAL.

    The text is run through the cached preprocessor, tokenized to a fixed
    length of ``MAX_SEQUENCE_LENGTH``, and passed to the cached model with
    gradients disabled.

    Args:
        text: Raw article text.

    Returns:
        dict with keys:
            ``prediction``        - arg-max class index (1 is reported as FAKE),
            ``label``             - ``'FAKE'`` or ``'REAL'``,
            ``confidence``        - the largest softmax probability,
            ``probabilities``     - ``{'REAL': p0, 'FAKE': p1}``,
            ``attention_weights`` - NumPy array taken from the first item of
                                    the model's ``attention_weights`` output.
    """
    model, tokenizer = load_model_and_tokenizer()
    cleaner = get_preprocessor()
    cleaned_text = cleaner.preprocess_text(text)

    encoded = tokenizer.encode_plus(
        cleaned_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        model_output = model(input_ids, attention_mask)
        logits = model_output['logits']
        class_probs = torch.softmax(logits, dim=1)
        predicted = torch.argmax(logits, dim=1)
        # Only the first (and only) item of the batch is of interest.
        first_item_attention = model_output['attention_weights'][0]
        attention_array = first_item_attention.cpu().numpy()

    predicted_index = predicted.item()
    if predicted_index == 1:
        label = 'FAKE'
    else:
        label = 'REAL'

    top_probability = torch.max(class_probs, dim=1)[0].item()
    per_class = {
        'REAL': class_probs[0][0].item(),
        'FAKE': class_probs[0][1].item()
    }

    return {
        'prediction': predicted_index,
        'label': label,
        'confidence': top_probability,
        'probabilities': per_class,
        'attention_weights': attention_array
    }
def plot_confidence(probabilities):
    """Build a bar chart of the per-class probabilities.

    Args:
        probabilities: Mapping of class label to probability, e.g.
            ``{'REAL': 0.2, 'FAKE': 0.8}``.

    Returns:
        plotly.graph_objects.Figure: One bar per label, annotated with the
        probability formatted as a percentage, on a fixed 0-100% y axis.
    """
    # Colour is tied to the label rather than to the position in the mapping,
    # so the bars stay correctly coloured whatever the key order is. Labels
    # other than REAL/FAKE fall back to the app's accent colour.
    label_colors = {'REAL': '#10b981', 'FAKE': '#ef4444'}
    fallback_color = '#6366f1'

    labels = list(probabilities.keys())
    values = list(probabilities.values())

    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=values,
            text=[f'{p:.1%}' for p in values],
            textposition='auto',
            marker=dict(
                color=[label_colors.get(label, fallback_color) for label in labels],
                line=dict(color='#ffffff', width=1),
            ),
        )
    ])
    # Axis title fonts are given through title.font: the older axis-level
    # `titlefont` attribute is deprecated and rejected by recent Plotly versions.
    fig.update_layout(
        title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(
            title=dict(text='Classification', font=dict(size=12)),
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title=dict(text='Probability', font=dict(size=12)),
            range=[0, 1],
            tickformat='.0%',
            tickfont=dict(size=10),
        ),
        template='plotly_white',
        height=300,
        margin=dict(t=60, b=60)
    )
    return fig
def plot_attention(text, attention_weights):
    """Build a bar chart of attention scores for the first words of ``text``.

    At most the first 20 whitespace-separated words are shown. The weights are
    flattened to one dimension and both sequences are trimmed to a common
    length, so the chart always has one score per displayed word. Empty text
    or empty weights produce an empty chart instead of raising.

    Args:
        text: The text whose words label the x axis.
        attention_weights: Array-like of attention scores.

    Returns:
        plotly.graph_objects.Figure: Bars whose opacity scales with each score
        relative to the largest displayed score.
    """
    # NOTE(review): words come from a plain whitespace split of `text`, while
    # the weights presumably follow the model's own tokenization - confirm
    # that position i of the weights really corresponds to word i.
    tokens = text.split()[:20]
    weights = np.asarray(attention_weights, dtype=float).ravel()

    # Guarantee x and y have identical lengths whatever shape was passed in.
    shown = min(len(tokens), weights.size)
    tokens = tokens[:shown]
    weights = weights[:shown]

    # Scale to the peak for colouring; guard against empty or all-zero input.
    peak = float(weights.max()) if shown else 0.0
    normalized_weights = weights / peak if peak > 0 else weights
    # Keep the alpha channel computed below inside the valid 0.4..1.0 range.
    normalized_weights = np.clip(normalized_weights, 0.0, 1.0)
    colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in normalized_weights]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=weights,
            text=[f'{float(w):.3f}' for w in weights],
            textposition='auto',
            marker=dict(color=colors),
        )
    ])
    # Axis title fonts are given through title.font: the older axis-level
    # `titlefont` attribute is deprecated and rejected by recent Plotly versions.
    fig.update_layout(
        title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(
            title=dict(text='Words', font=dict(size=12)),
            tickangle=45,
            tickfont=dict(size=10),
        ),
        yaxis=dict(
            title=dict(text='Attention Score', font=dict(size=12)),
            tickfont=dict(size=10),
        ),
        template='plotly_white',
        height=350,
        margin=dict(t=60, b=80)
    )
    return fig
def main():
    """Render the TruthCheck page and run an analysis when requested.

    Draws the sidebar, header, hero and about sections, a text area for the
    article and an "Analyze Now" button. When the button is pressed and the
    text contains at least ten whitespace-separated words, the article is
    classified with ``predict_news`` and the verdict card, the confidence
    chart and the attention chart are displayed. Shorter input shows a
    validation error, and any exception raised during analysis is reported
    to the user with ``st.error``.
    """
    # Minimum number of words required before analysing; this is the figure
    # quoted in the validation message further down.
    min_words = 10

    # Sidebar
    with st.sidebar:
        st.markdown("## TruthCheck Menu")
        st.markdown("Navigate through the options below:")
        # NOTE(review): the return values of these buttons are not used, so
        # they currently do not navigate anywhere.
        st.button("Home", disabled=True)
        st.button("Analyze News", key="nav_analyze")
        st.button("About", key="nav_about")
        st.markdown("---")
        st.markdown("**Contact**")
        # NOTE(review): "[email protected]" looks like a redacted placeholder
        # rather than a real address - replace with the actual contact.
        st.markdown("📧 [email protected]")

    # Header
    st.markdown("""
<div class="header">
<div class="container">
<h1 class="header-title">🛡️ TruthCheck</h1>
</div>
</div>
""", unsafe_allow_html=True)

    # Hero Section
    # NOTE(review): the <img> below uses the relative path "hero.png" - verify
    # that the image is actually reachable from the rendered page.
    st.markdown("""
<div class="container">
<div class="hero">
<div class="hero-left">
<h2 class="hero-title">Instant Fake News Detection</h2>
<p class="hero-text">
Verify news articles with our AI-powered tool, driven by BERT and BiLSTM for fast and accurate authenticity analysis.
</p>
</div>
<div class="hero-right">
<img src="hero.png" alt="TruthCheck Illustration">
</div>
</div>
</div>
""", unsafe_allow_html=True)

    # About Section
    st.markdown("""
<div class="container">
<div class="about-section">
<h2 class="about-title">About TruthCheck</h2>
<p class="about-text">
TruthCheck uses a hybrid BERT-BiLSTM model to detect fake news with high accuracy. Paste an article below for instant analysis.
</p>
</div>
</div>
""", unsafe_allow_html=True)

    # Input Section
    st.markdown('<div class="container"><div class="input-container">', unsafe_allow_html=True)
    news_text = st.text_area(
        "Analyze a News Article",
        height=150,
        placeholder="Paste your news article here for instant AI analysis...",
        key="news_input"
    )
    st.markdown('</div>', unsafe_allow_html=True)

    # Analyze Button (placed in the wider middle column to centre it)
    st.markdown('<div class="container">', unsafe_allow_html=True)
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        analyze_button = st.button("🔍 Analyze Now", key="analyze_button")
    st.markdown('</div>', unsafe_allow_html=True)

    if analyze_button:
        # Validate by word count so the check agrees with the error message.
        word_count = len(news_text.split()) if news_text else 0
        if word_count >= min_words:
            with st.spinner("Analyzing article..."):
                # Top-level UI boundary: report any failure to the user
                # instead of letting the page crash.
                try:
                    result = predict_news(news_text)
                    st.markdown('<div class="container"><div class="results-container">', unsafe_allow_html=True)

                    # Prediction Result
                    col1, col2 = st.columns([1, 1], gap="medium")
                    with col1:
                        if result['label'] == 'FAKE':
                            st.markdown(f'''
<div class="result-card fake-news">
<div class="prediction-badge">🚨 Fake News Detected <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>Our AI has identified this content as likely misinformation based on linguistic patterns and content analysis.</p>
</div>
''', unsafe_allow_html=True)
                        else:
                            st.markdown(f'''
<div class="result-card real-news">
<div class="prediction-badge">✅ Authentic News <span class="confidence-score">{result["confidence"]:.1%}</span></div>
<p>This content appears to be legitimate based on professional writing style and factual consistency.</p>
</div>
''', unsafe_allow_html=True)
                    with col2:
                        st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                        st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
                        st.markdown('</div>', unsafe_allow_html=True)

                    # Attention Analysis
                    st.markdown('<div class="chart-container">', unsafe_allow_html=True)
                    st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
                    st.markdown('</div></div></div>', unsafe_allow_html=True)
                except Exception as e:
                    st.markdown('<div class="container">', unsafe_allow_html=True)
                    st.error(f"Error: {str(e)}. Please try again or contact support.")
                    st.markdown('</div>', unsafe_allow_html=True)
        else:
            st.markdown('<div class="container">', unsafe_allow_html=True)
            st.error("Please enter a news article (at least 10 words) for analysis.")
            st.markdown('</div>', unsafe_allow_html=True)

    # Footer
    st.markdown("""
<div class="footer">
<p style="text-align: center; font-weight: 600; font-size: 16px;">💻 Developed with ❤️ using Streamlit | © 2025</p>
</div>
""", unsafe_allow_html=True)


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()