Analyze News Article
Paste any news article below and our AI will analyze it for authenticity. Get detailed insights including confidence scores and attention analysis.
# """) # st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True) # # Show model explanation # st.subheader("Model Explanation") # if result['label'] == 'FAKE': # st.write(""" # The model identified this as fake news based on: # - Linguistic patterns typical of fake news # - Inconsistencies in the content # - Attention weights on suspicious phrases # """) # else: # st.write(""" # The model identified this as real news based on: # - Credible language patterns # - Consistent information # - Attention weights on factual statements # """) # else: # st.warning("Please enter a news article to analyze.") # if __name__ == "__main__": # main() import streamlit as st import torch import pandas as pd import numpy as np from pathlib import Path import sys import plotly.express as px import plotly.graph_objects as go from transformers import BertTokenizer import nltk # Download required NLTK data try: nltk.data.find('tokenizers/punkt') except LookupError: nltk.download('punkt') try: nltk.data.find('corpora/stopwords') except LookupError: nltk.download('stopwords') try: nltk.data.find('tokenizers/punkt_tab') except LookupError: nltk.download('punkt_tab') try: nltk.data.find('corpora/wordnet') except LookupError: nltk.download('wordnet') # Add project root to Python path project_root = Path(__file__).parent.parent sys.path.append(str(project_root)) from src.models.hybrid_model import HybridFakeNewsDetector from src.config.config import * from src.data.preprocessor import TextPreprocessor # REMOVED st.set_page_config() - This should only be called once in the main entry point # Custom CSS for modern styling st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_model_and_tokenizer(): """Load the model and tokenizer (cached).""" model = HybridFakeNewsDetector( bert_model_name=BERT_MODEL_NAME, lstm_hidden_size=LSTM_HIDDEN_SIZE, lstm_num_layers=LSTM_NUM_LAYERS, dropout_rate=DROPOUT_RATE ) state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu')) model_state_dict = model.state_dict() filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict} model.load_state_dict(filtered_state_dict, strict=False) model.eval() tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME) return model, tokenizer @st.cache_resource def get_preprocessor(): """Get the text preprocessor (cached).""" return TextPreprocessor() def predict_news(text): """Predict if the given news is fake or real.""" model, tokenizer = load_model_and_tokenizer() preprocessor = get_preprocessor() processed_text = preprocessor.preprocess_text(text) encoding = tokenizer.encode_plus( processed_text, add_special_tokens=True, max_length=MAX_SEQUENCE_LENGTH, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' ) with torch.no_grad(): outputs = model( encoding['input_ids'], encoding['attention_mask'] ) probabilities = torch.softmax(outputs['logits'], dim=1) prediction = torch.argmax(outputs['logits'], dim=1) attention_weights = outputs['attention_weights'] attention_weights_np = attention_weights[0].cpu().numpy() return { 'prediction': prediction.item(), 'label': 'FAKE' if prediction.item() == 1 else 'REAL', 'confidence': torch.max(probabilities, dim=1)[0].item(), 'probabilities': { 'REAL': probabilities[0][0].item(), 'FAKE': probabilities[0][1].item() }, 'attention_weights': attention_weights_np } def plot_confidence(probabilities): """Plot prediction confidence.""" fig = go.Figure(data=[ go.Bar( 
def plot_confidence(probabilities):
    """Plot prediction confidence."""
    fig = go.Figure(data=[
        go.Bar(
            x=list(probabilities.keys()),
            y=list(probabilities.values()),
            text=[f'{p:.2%}' for p in probabilities.values()],
            textposition='auto',
            marker_color=['#22c55e', '#ef4444'],
            marker_line_color='rgba(0,0,0,0.1)',
            marker_line_width=1
        )
    ])
    fig.update_layout(
        title={
            'text': 'Prediction Confidence',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 18, 'family': 'Inter'}
        },
        xaxis_title='Class',
        yaxis_title='Probability',
        yaxis_range=[0, 1],
        template='plotly_white',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font={'family': 'Inter'}
    )
    return fig


def plot_attention(text, attention_weights):
    """Plot attention weights."""
    tokens = text.split()
    attention_weights = attention_weights[:len(tokens)]
    if isinstance(attention_weights, (list, np.ndarray)):
        attention_weights = np.array(attention_weights).flatten()
    formatted_weights = [f'{float(w):.2f}' for w in attention_weights]
    # Create color scale based on attention weights
    colors = ['rgba(102, 126, 234, ' + str(0.3 + 0.7 * (w / max(attention_weights))) + ')'
              for w in attention_weights]
    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=attention_weights,
            text=formatted_weights,
            textposition='auto',
            marker_color=colors,
            marker_line_color='rgba(102, 126, 234, 0.8)',
            marker_line_width=1
        )
    ])
    fig.update_layout(
        title={
            'text': 'Attention Weights Analysis',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 18, 'family': 'Inter'}
        },
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45,
        template='plotly_white',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        font={'family': 'Inter'}
    )
    return fig
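# Illustrative helper (not called by the page): plot_attention() above splits the raw
# article on whitespace, which only approximates the BERT wordpiece tokens the
# attention weights were actually computed over. If token-accurate labels are wanted,
# they could be recovered from the encoding itself along these lines; the helper name
# and exact behaviour are assumptions, not part of the project code.
def attention_tokens_from_encoding(encoding, tokenizer, attention_weights):
    """Pair each non-padding wordpiece with its attention weight."""
    ids = encoding['input_ids'][0].tolist()
    mask = encoding['attention_mask'][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(ids)
    pairs = [(tok, float(w)) for tok, w, m in zip(tokens, attention_weights, mask) if m == 1]
    # Drop BERT's special markers so only real word pieces are plotted
    return [(tok, w) for tok, w in pairs if tok not in ('[CLS]', '[SEP]', '[PAD]')]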
def main():
    # Hero section
    st.title("📰 Fake News Detection System")
    st.markdown("""
    Advanced AI-powered fake news detection using cutting-edge deep learning technology.
    Get instant, accurate analysis of news articles with our hybrid BERT-BiLSTM model.
    """)

    # Feature overview
    st.markdown("""
    Our advanced AI model combines multiple technologies for superior accuracy:

    - **BERT:** a state-of-the-art transformer for deep contextual understanding of news content
    - **BiLSTM:** bidirectional LSTM networks that capture sequential patterns and dependencies in text structure
    - **Attention:** attention layers that provide interpretable insights into the model's decision-making process
    """)

    # Analysis section
    st.header("Analyze News Article")
    st.write(
        "Paste any news article below and our AI will analyze it for authenticity. "
        "Get detailed insights including confidence scores and attention analysis."
    )
    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here..."
    )

    if st.button("Analyze"):
        if news_text:
            with st.spinner("Analyzing the news article..."):
                result = predict_news(news_text)

            # Prediction and confidence scores
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Prediction")
                if result['label'] == 'FAKE':
                    st.error(f"🔴 This news is likely FAKE (Confidence: {result['confidence']:.2%})")
                else:
                    st.success(f"🟢 This news is likely REAL (Confidence: {result['confidence']:.2%})")
            with col2:
                st.subheader("Confidence Scores")
                st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)

            # Attention analysis
            st.subheader("Attention Analysis")
            st.write(
                "The visualization below shows which words our AI model focused on while "
                "making its prediction. Darker colors indicate higher attention weights."
            )
            st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)

            # Detailed analysis
            st.markdown("### 🔍 Detailed Analysis")
            if result['label'] == 'FAKE':
                st.markdown("💡 Recommendation: Verify this information through multiple reliable sources before sharing.")
            else:
                st.markdown("💡 Note: While likely authentic, always cross-reference important news from multiple sources.")
        else:
            st.warning("Please enter a news article to analyze.")


if __name__ == "__main__":
    main()
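# For reference, a hypothetical sketch of the main entry point mentioned in the comment
# above ("st.set_page_config() ... should only be called once in the main entry point").
# The file name, titles, and icon below are assumptions rather than project code, so the
# sketch is left commented out; the point is only that page config lives in app.py, not here.
#
#     # app.py
#     import streamlit as st
#
#     st.set_page_config(
#         page_title="Fake News Detection System",
#         page_icon="📰",
#         layout="wide",
#     )
#     st.title("📰 Fake News Detection System")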