"""TruthCheck — Streamlit page for the hybrid fake-news detector.

Loads a hybrid BERT + BiLSTM + attention model, lets the user paste a news
article, and visualises the prediction, per-class confidences, and the
model's attention weights.

NOTE(review): this file was recovered from a whitespace-mangled copy in
which the inline HTML/CSS passed to ``st.markdown(..., unsafe_allow_html=True)``
(custom container divs and a Poppins-font <style> block) was lost. The
markup below is a minimal reconstruction from the surviving text content —
confirm against the project's original design before shipping.
"""

import sys
from pathlib import Path

import nltk
import numpy as np
import pandas as pd  # kept: may be used elsewhere when this page is imported
import plotly.express as px  # kept: may be used elsewhere
import plotly.graph_objects as go
import streamlit as st
import torch
from transformers import BertTokenizer

# Download required NLTK data (each is a no-op if already present).
for _resource, _path in [
    ('punkt', 'tokenizers/punkt'),
    ('stopwords', 'corpora/stopwords'),
    ('punkt_tab', 'tokenizers/punkt_tab'),
    ('wordnet', 'corpora/wordnet'),
]:
    try:
        nltk.data.find(_path)
    except LookupError:
        nltk.download(_resource)

# Add the project root to the Python path so the `src` package resolves.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor

# Page config is set in the main app.py.


@st.cache_resource
def load_model_and_tokenizer():
    """Load the trained model and BERT tokenizer (cached across reruns).

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode on CPU.
    """
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE,
    )
    # NOTE(review): torch.load on an untrusted checkpoint can execute
    # arbitrary code via pickle; prefer weights_only=True (torch >= 2.0)
    # if the installed torch supports it.
    state_dict = torch.load(
        SAVED_MODELS_DIR / "final_model.pt",
        map_location=torch.device('cpu'),
    )
    # Keep only keys the current architecture defines, so checkpoints from
    # slightly different model revisions still load.
    model_state_dict = model.state_dict()
    filtered_state_dict = {
        k: v for k, v in state_dict.items() if k in model_state_dict
    }
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer


@st.cache_resource
def get_preprocessor():
    """Get the text preprocessor (cached across reruns)."""
    return TextPreprocessor()


def predict_news(text):
    """Classify a news article as FAKE or REAL.

    Args:
        text: Raw article text entered by the user.

    Returns:
        dict with keys:
            - 'prediction': int class index (1 = FAKE, 0 = REAL)
            - 'label': 'FAKE' or 'REAL'
            - 'confidence': float, probability of the predicted class
            - 'probabilities': {'REAL': float, 'FAKE': float}
            - 'attention_weights': 1-D numpy array of per-token attention
              for the first (only) sequence in the batch
    """
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()

    processed_text = preprocessor.preprocess_text(text)

    encoding = tokenizer.encode_plus(
        processed_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(
            encoding['input_ids'],
            encoding['attention_mask'],
        )
        probabilities = torch.softmax(outputs['logits'], dim=1)
        prediction = torch.argmax(outputs['logits'], dim=1)
        attention_weights = outputs['attention_weights']

    # First (and only) sequence of the batch, as numpy for plotting.
    attention_weights_np = attention_weights[0].cpu().numpy()

    return {
        'prediction': prediction.item(),
        'label': 'FAKE' if prediction.item() == 1 else 'REAL',
        'confidence': torch.max(probabilities, dim=1)[0].item(),
        'probabilities': {
            'REAL': probabilities[0][0].item(),
            'FAKE': probabilities[0][1].item(),
        },
        'attention_weights': attention_weights_np,
    }


def plot_confidence(probabilities):
    """Build a bar chart of per-class prediction probabilities.

    Args:
        probabilities: mapping of class label -> probability in [0, 1].

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure(data=[
        go.Bar(
            x=list(probabilities.keys()),
            y=list(probabilities.values()),
            text=[f'{p:.2%}' for p in probabilities.values()],
            textposition='auto',
        )
    ])
    fig.update_layout(
        title='Prediction Confidence',
        xaxis_title='Class',
        yaxis_title='Probability',
        yaxis_range=[0, 1],
    )
    return fig


def plot_attention(text, attention_weights):
    """Build a bar chart of attention weight per (whitespace) token.

    NOTE(review): the model attends over BERT wordpiece tokens, while this
    plot splits on whitespace and truncates the weights to match — so the
    alignment between bars and words is approximate. Kept as-is to preserve
    the page's existing behaviour.

    Args:
        text: The raw article text (whitespace-tokenised for the x-axis).
        attention_weights: array-like of attention weights.

    Returns:
        plotly.graph_objects.Figure
    """
    tokens = text.split()
    attention_weights = attention_weights[:len(tokens)]  # truncate to match tokens

    # Ensure attention weights are a flat numpy array for plotting.
    if isinstance(attention_weights, (list, np.ndarray)):
        attention_weights = np.array(attention_weights).flatten()

    # Format weights for display on the bars.
    formatted_weights = [f'{float(w):.2f}' for w in attention_weights]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=attention_weights,
            text=formatted_weights,
            textposition='auto',
        )
    ])
    fig.update_layout(
        title='Attention Weights',
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45,
    )
    return fig


def main():
    """Render the TruthCheck analysis page."""
    # NOTE(review): the original custom container divs / CSS were lost in
    # extraction (see module docstring); this is a minimal reconstruction.
    st.markdown('<div class="main-container">', unsafe_allow_html=True)

    # Header section.
    st.markdown(
        '<h1>📰 TruthCheck - Advanced Fake News Detection System</h1>',
        unsafe_allow_html=True,
    )

    # Intro blurb.
    st.markdown(
        'This application uses a hybrid deep learning model '
        '(BERT + BiLSTM + Attention) to detect fake news articles. '
        'Enter a news article below to analyze it.'
    )

    # News analysis section.
    st.markdown('## 📋 News Analysis')

    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here...",
    )

    if st.button("Analyze"):
        if news_text:
            with st.spinner("Analyzing the news article..."):
                result = predict_news(news_text)

            col1, col2 = st.columns(2)

            with col1:
                st.markdown('### 🔍 Prediction')
                if result['label'] == 'FAKE':
                    st.markdown(
                        f'🚨 Fake News Detected — {result["confidence"]:.2%}'
                    )
                else:
                    st.markdown(
                        f'✅ Authentic News — {result["confidence"]:.2%}'
                    )

            with col2:
                st.markdown('### 📊 Confidence Scores')
                st.plotly_chart(
                    plot_confidence(result['probabilities']),
                    use_container_width=True,
                )

            # Attention visualisation.
            st.markdown('### 👁️ Attention Analysis')
            st.markdown(
                'The attention weights show which parts of the text the '
                'model focused on while making its prediction. Higher '
                'weights indicate more important tokens.'
            )
            st.plotly_chart(
                plot_attention(news_text, result['attention_weights']),
                use_container_width=True,
            )

            # Model explanation.
            st.markdown('### 📝 Model Explanation')
            if result['label'] == 'FAKE':
                st.markdown(
                    'The model identified this as fake news based on:\n\n'
                    '- Linguistic patterns typical of fake news\n'
                    '- Inconsistencies in the content\n'
                    '- Attention weights on suspicious phrases'
                )
            else:
                st.markdown(
                    'The model identified this as real news based on:\n\n'
                    '- Credible language patterns\n'
                    '- Consistent information\n'
                    '- Attention weights on factual statements'
                )
        else:
            st.warning("Please enter a news article to analyze.")

    # Footer.
    st.markdown("---")
    st.markdown('</div>', unsafe_allow_html=True)  # close main-container


if __name__ == "__main__":
    main()