import streamlit as st import torch import pandas as pd import numpy as np from pathlib import Path import sys import plotly.express as px import plotly.graph_objects as go from transformers import BertTokenizer import nltk # Download required NLTK data try: nltk.data.find('tokenizers/punkt') except LookupError: nltk.download('punkt') try: nltk.data.find('corpora/stopwords') except LookupError: nltk.download('stopwords') try: nltk.data.find('tokenizers/punkt_tab') except LookupError: nltk.download('punkt_tab') try: nltk.data.find('corpora/wordnet') except LookupError: nltk.download('wordnet') # Add project root to Python path project_root = Path(__file__).parent.parent sys.path.append(str(project_root)) from src.models.hybrid_model import HybridFakeNewsDetector from src.config.config import * from src.data.preprocessor import TextPreprocessor # Custom CSS with Poppins font and increased font sizes st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_model_and_tokenizer(): """Load the model and tokenizer (cached).""" model = HybridFakeNewsDetector( bert_model_name=BERT_MODEL_NAME, lstm_hidden_size=LSTM_HIDDEN_SIZE, lstm_num_layers=LSTM_NUM_LAYERS, dropout_rate=DROPOUT_RATE ) state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu')) model_state_dict = model.state_dict() filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} model.load_state_dict(filtered_state_dict, strict=False) model.eval() tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME) return model, tokenizer @st.cache_resource def get_preprocessor(): """Get the text preprocessor (cached).""" return TextPreprocessor() def predict_news(text): """Predict if the given news is fake or real.""" model, tokenizer = load_model_and_tokenizer() preprocessor = get_preprocessor() processed_text = preprocessor.preprocess_text(text) encoding = tokenizer.encode_plus( processed_text, add_special_tokens=True, max_length=MAX_SEQUENCE_LENGTH, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' ) with torch.no_grad(): outputs = model( encoding['input_ids'], encoding['attention_mask'] ) probabilities = torch.softmax(outputs['logits'], dim=1) prediction = torch.argmax(outputs['logits'], dim=1) attention_weights = outputs['attention_weights'] attention_weights_np = attention_weights[0].cpu().numpy() return { 'prediction': prediction.item(), 'label': 'FAKE' if prediction.item() == 1 else 'REAL', 'confidence': torch.max(probabilities, dim=1)[0].item(), 'probabilities': { 'REAL': probabilities[0][0].item(), 'FAKE': probabilities[0][1].item() }, 'attention_weights': attention_weights_np } def plot_confidence(probabilities): """Plot prediction confidence with simplified styling.""" fig = go.Figure(data=[ go.Bar( x=list(probabilities.keys()), y=list(probabilities.values()), text=[f'{p:.1%}' for p in probabilities.values()], textposition='auto', marker=dict( color=['#10b981', '#ef4444'], line=dict(color='#ffffff', width=1), ), ) ]) fig.update_layout( title={'text': 'Prediction Confidence', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}}, xaxis=dict(title='Classification', titlefont={'size': 12}, tickfont={'size': 10}), yaxis=dict(title='Probability', range=[0, 1], tickformat='.0%', titlefont={'size': 12}, tickfont={'size': 10}), template='plotly_white', height=300, margin=dict(t=60, b=60) ) return fig def plot_attention(text, attention_weights): """Plot attention weights with simplified styling.""" tokens = text.split()[:20] attention_weights = attention_weights[:len(tokens)] if isinstance(attention_weights, (list, np.ndarray)): attention_weights = np.array(attention_weights).flatten() normalized_weights = attention_weights / max(attention_weights) if max(attention_weights) > 0 else attention_weights colors = [f'rgba(99, 102, 241, {0.4 + 0.6 * float(w)})' for w in normalized_weights] fig = go.Figure(data=[ go.Bar( x=tokens, y=attention_weights, text=[f'{float(w):.3f}' for w in attention_weights], textposition='auto', marker=dict(color=colors), ) ]) fig.update_layout( title={'text': 'Attention Weights', 'x': 0.5, 'xanchor': 'center', 'font': {'size': 18}}, xaxis=dict(title='Words', tickangle=45, titlefont={'size': 12}, tickfont={'size': 10}), yaxis=dict(title='Attention Score', titlefont={'size': 12}, tickfont={'size': 10}), template='plotly_white', height=350, margin=dict(t=60, b=80) ) return fig def main(): # Header st.markdown("""
Verify news articles with our AI-powered tool, driven by BERT and BiLSTM for fast and accurate authenticity analysis.
TruthCheck uses a hybrid BERT-BiLSTM model to detect fake news with high accuracy. Paste an article below for instant analysis.
Our AI has identified this content as likely misinformation based on linguistic patterns and content analysis.
This content appears to be legitimate based on professional writing style and factual consistency.