import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.express as px
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk

# Download required NLTK data
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')

# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor

# Custom CSS with Poppins font and increased font sizes
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_model_and_tokenizer():
    """Load the model and tokenizer (cached)."""
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )
    state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
    # Keep only the weights whose names match the current architecture, then
    # load non-strictly so minor naming differences don't abort the app.
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer


@st.cache_resource
def get_preprocessor():
    """Get the text preprocessor (cached)."""
    return TextPreprocessor()


def predict_news(text):
    """Predict if the given news is fake or real."""
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()

    processed_text = preprocessor.preprocess_text(text)
    encoding = tokenizer.encode_plus(
        processed_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(
            encoding['input_ids'],
            encoding['attention_mask']
        )
        probabilities = torch.softmax(outputs['logits'], dim=1)
        prediction = torch.argmax(outputs['logits'], dim=1)
        attention_weights = outputs['attention_weights']

    attention_weights_np = attention_weights[0].cpu().numpy()

    return {
        'prediction': prediction.item(),
        'label': 'FAKE' if prediction.item() == 1 else 'REAL',
        'confidence': torch.max(probabilities, dim=1)[0].item(),
        'probabilities': {
            'REAL': probabilities[0][0].item(),
            'FAKE': probabilities[0][1].item()
        },
        'attention_weights': attention_weights_np
    }


def plot_confidence(probabilities):
    """Plot prediction confidence with simplified styling."""
    fig = go.Figure(data=[
        go.Bar(
            x=list(probabilities.keys()),
            y=list(probabilities.values()),
            text=[f'{p:.1%}' for p in probabilities.values()],
            textposition='auto',
            marker=dict(
                color=['#10b981', '#ef4444'],  # green for REAL, red for FAKE
                line=dict(color='#ffffff', width=1),
            ),
        )
    ])

    fig.update_layout(
        title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(title={'text': 'Classification', 'font': {'size': 12}}, tickfont={'size': 10}),
        yaxis=dict(title={'text': 'Probability', 'font': {'size': 12}}, range=[0, 1], tickformat='.0%', tickfont={'size': 10}),
        template='plotly_white',
        height=300,
        margin=dict(t=60, b=60)
    )
    return fig


def plot_attention(text, attention_weights):
    """Plot attention weights with simplified styling."""
    # Show at most the first 20 whitespace-separated words of the input
    tokens = text.split()[:20]
    # Note: the weights come from the model's attention over BERT subword
    # tokens, so pairing them with whitespace-split words is an approximation.
    attention_weights = attention_weights[:len(tokens)]
    if isinstance(attention_weights, (list, np.ndarray)):
        attention_weights = np.array(attention_weights).flatten()

    # Scale bar opacity by the normalized attention score
    normalized_weights = attention_weights / max(attention_weights) if max(attention_weights) > 0 else attention_weights
    colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in normalized_weights]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=attention_weights,
            text=[f'{float(w):.3f}' for w in attention_weights],
            textposition='auto',
            marker=dict(color=colors),
        )
    ])

    fig.update_layout(
        title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}},
        xaxis=dict(title={'text': 'Words', 'font': {'size': 12}}, tickangle=45, tickfont={'size': 10}),
        yaxis=dict(title={'text': 'Attention Score', 'font': {'size': 12}}, tickfont={'size': 10}),
        template='plotly_white',
        height=350,
        margin=dict(t=60, b=80)
    )
    return fig


def main():
    # Header
    st.markdown("""

🛡️ TruthCheck
    """, unsafe_allow_html=True)

    # Hero Section
    st.markdown("""

Instant Fake News Detection

Verify news articles with our AI-powered tool, driven by BERT and BiLSTM for fast and accurate authenticity analysis.

Fake News Detector
    """, unsafe_allow_html=True)

    # About Section
    st.markdown("""

About TruthCheck

TruthCheck uses a hybrid BERT-BiLSTM model to detect fake news with high accuracy. Paste an article below for instant analysis.
    """, unsafe_allow_html=True)

    # Input Section
    news_text = st.text_area(
        "Analyze a News Article",
        height=150,
        placeholder="Paste your news article here for instant AI analysis...",
        key="news_input"
    )

    # Analyze Button
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        analyze_button = st.button("🔍 Analyze Now", key="analyze_button")

    if analyze_button:
        # Require at least 10 words, matching the error message shown below
        if news_text and len(news_text.split()) >= 10:
            with st.spinner("Analyzing article..."):
                try:
                    result = predict_news(news_text)

                    # Prediction Result
                    col1, col2 = st.columns([1, 1], gap="medium")
                    with col1:
                        if result['label'] == 'FAKE':
                            st.markdown(f'''
🚨 Fake News Detected {result["confidence"]:.1%}

Our AI has identified this content as likely misinformation based on linguistic patterns and content analysis.
                            ''', unsafe_allow_html=True)
                        else:
                            st.markdown(f'''
✅ Authentic News {result["confidence"]:.1%}

This content appears to be legitimate based on professional writing style and factual consistency.
                            ''', unsafe_allow_html=True)
                    with col2:
                        st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)

                    # Attention Analysis
                    st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
                except Exception as e:
                    st.error(f"Error: {str(e)}. Please try again or contact support.")
        else:
            st.error("Please enter a news article (at least 10 words) for analysis.")


if __name__ == "__main__":
    main()
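# Hypothetical smoke test outside the Streamlit UI -- a minimal sketch, assuming
# the trained weights exist at SAVED_MODELS_DIR / "final_model.pt" and that this
# module can be imported (the import path below is an assumption, not part of
# the app):
#
#     from streamlit_app import predict_news
#     result = predict_news("Paste a full news article here to sanity-check the model...")
#     print(result['label'], f"{result['confidence']:.1%}")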