"""TrueCheck: Streamlit page for fake-news detection.

Classifies a pasted news article as REAL or FAKE with the project's hybrid
BERT + BiLSTM + attention model and visualises the class probabilities and
the model's attention weights.

st.set_page_config() is intentionally not called here; it must be called
only once, in the main entry point.
"""
import logging
import sys
from pathlib import Path

import nltk
import numpy as np
import pandas as pd  # noqa: F401  (kept from the original import set)
import plotly.express as px  # noqa: F401  (kept from the original import set)
import plotly.graph_objects as go
import streamlit as st
import torch
from transformers import BertTokenizer

logger = logging.getLogger(__name__)

# Required NLTK data: (path checked with nltk.data.find, package for nltk.download).
_NLTK_RESOURCES = (
    ("tokenizers/punkt", "punkt"),
    ("corpora/stopwords", "stopwords"),
    ("tokenizers/punkt_tab", "punkt_tab"),
    ("corpora/wordnet", "wordnet"),
)


@st.cache_resource
def _ensure_nltk_data():
    """Download any missing NLTK resources; cached so it runs once per process."""
    for resource_path, package in _NLTK_RESOURCES:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)
    return True


# Kept ahead of the project imports, as in the original, in case those
# modules need the NLTK data at import time.
_ensure_nltk_data()

# Make the project root importable so the `src.*` imports below resolve.
# Streamlit re-executes this script on every interaction, so only append the
# entry once instead of growing sys.path by one item per rerun.
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.models.hybrid_model import HybridFakeNewsDetector  # noqa: E402
from src.config.config import (  # noqa: E402
    BERT_MODEL_NAME,
    DROPOUT_RATE,
    LSTM_HIDDEN_SIZE,
    LSTM_NUM_LAYERS,
    MAX_SEQUENCE_LENGTH,
    SAVED_MODELS_DIR,
)
from src.data.preprocessor import TextPreprocessor  # noqa: E402

# Custom CSS for modern styling
st.markdown(""" """, unsafe_allow_html=True)

# Bar colours per class label; keyed by label so colours never depend on dict order.
_CLASS_COLORS = {"REAL": "#22c55e", "FAKE": "#ef4444"}
_FALLBACK_CLASS_COLOR = "#94a3b8"


@st.cache_resource
def load_model_and_tokenizer():
    """Load the trained model and its BERT tokenizer (cached per process).

    Returns:
        tuple: ``(model, tokenizer)``; the model is mapped to CPU and put in
        eval mode.
    """
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE,
    )

    # NOTE(review): torch.load unpickles the file; only ever point this at a
    # trusted checkpoint. Consider weights_only=True once the minimum
    # supported torch version allows it.
    state_dict = torch.load(
        SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device("cpu")
    )

    # Keep only checkpoint entries the current architecture defines.
    model_keys = model.state_dict().keys()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
    load_result = model.load_state_dict(filtered_state_dict, strict=False)

    # strict=False would otherwise hide parameters that were never loaded and
    # therefore keep their random initialisation.
    dropped = len(state_dict) - len(filtered_state_dict)
    if dropped:
        logger.warning("Ignored %d checkpoint keys not present in the model", dropped)
    missing = getattr(load_result, "missing_keys", ())
    if missing:
        logger.warning(
            "%d model parameters were not found in the checkpoint: %s",
            len(missing),
            ", ".join(missing),
        )
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer


@st.cache_resource
def get_preprocessor():
    """Get the text preprocessor (cached)."""
    return TextPreprocessor()


def predict_news(text):
    """Classify *text* as REAL or FAKE.

    Args:
        text: Raw article text as entered by the user.

    Returns:
        dict with keys ``prediction`` (int class index, 1 == FAKE), ``label``
        (``'FAKE'``/``'REAL'``), ``confidence`` (max class probability),
        ``probabilities`` (``{'REAL': p0, 'FAKE': p1}``), ``attention_weights``
        (numpy array for the first/only sequence) and ``tokens`` (the
        tokenizer's tokens for the non-padding positions the model saw).
    """
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()

    processed_text = preprocessor.preprocess_text(text)
    encoding = tokenizer.encode_plus(
        processed_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(encoding["input_ids"], encoding["attention_mask"])

    logits = outputs["logits"]
    probabilities = torch.softmax(logits, dim=1)
    prediction = torch.argmax(logits, dim=1).item()
    attention_weights_np = outputs["attention_weights"][0].cpu().numpy()

    # Tokens the model actually consumed (wordpieces incl. [CLS]/[SEP]),
    # with padding positions removed via the attention mask.
    input_ids = encoding["input_ids"][0]
    real_positions = encoding["attention_mask"][0].bool()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[real_positions].tolist())

    return {
        "prediction": prediction,
        "label": "FAKE" if prediction == 1 else "REAL",
        "confidence": torch.max(probabilities, dim=1)[0].item(),
        "probabilities": {
            "REAL": probabilities[0][0].item(),
            "FAKE": probabilities[0][1].item(),
        },
        "attention_weights": attention_weights_np,
        "tokens": tokens,
    }


def plot_confidence(probabilities):
    """Return a bar chart of per-class probabilities.

    Args:
        probabilities: Mapping of class label to probability in ``[0, 1]``.
    """
    labels = list(probabilities.keys())
    values = list(probabilities.values())
    fig = go.Figure(data=[
        go.Bar(
            x=labels,
            y=values,
            text=[f"{p:.2%}" for p in values],
            textposition="auto",
            marker_color=[_CLASS_COLORS.get(label, _FALLBACK_CLASS_COLOR) for label in labels],
            marker_line_color="rgba(0,0,0,0.1)",
            marker_line_width=1,
        )
    ])
    fig.update_layout(
        title={
            "text": "Prediction Confidence",
            "x": 0.5,
            "xanchor": "center",
            "font": {"size": 18, "family": "Inter"},
        },
        xaxis_title="Class",
        yaxis_title="Probability",
        yaxis_range=[0, 1],
        template="plotly_white",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font={"family": "Inter"},
    )
    return fig


def plot_attention(text, attention_weights, tokens=None):
    """Return a bar chart of attention weight per token.

    Args:
        text: Original input text; its whitespace-split words are used as
            labels only when *tokens* is not given (legacy behaviour).
        attention_weights: Array-like of weights; flattened before use.
            NOTE(review): assumes one weight per input position of the
            encoded sequence - confirm against HybridFakeNewsDetector.
        tokens: Optional token labels aligned with *attention_weights*
            (e.g. ``predict_news(...)['tokens']``).
    """
    labels = list(tokens) if tokens is not None else text.split()
    # Flatten first, then truncate, so (n, 1) and (1, n) shapes both work.
    weights = np.asarray(attention_weights, dtype=float).ravel()
    n = min(len(labels), weights.size)
    labels, weights = labels[:n], weights[:n]

    formatted_weights = [f"{w:.2f}" for w in weights]

    # Opacity scales with weight relative to the peak (computed once); guard
    # against an all-zero/empty vector and keep alpha inside [0, 1].
    peak = float(weights.max()) if n else 0.0
    if peak > 0:
        alphas = np.clip(0.3 + 0.7 * (weights / peak), 0.0, 1.0)
    else:
        alphas = np.full(n, 0.3)
    colors = [f"rgba(102, 126, 234, {a})" for a in alphas]

    # Positional x values keep repeated tokens as separate bars; the token
    # text is shown through the tick labels and hover text instead.
    positions = list(range(n))
    fig = go.Figure(data=[
        go.Bar(
            x=positions,
            y=weights,
            text=formatted_weights,
            textposition="auto",
            hovertext=labels,
            marker_color=colors,
            marker_line_color="rgba(102, 126, 234, 0.8)",
            marker_line_width=1,
        )
    ])
    fig.update_layout(
        title={
            "text": "Attention Weights Analysis",
            "x": 0.5,
            "xanchor": "center",
            "font": {"size": 18, "family": "Inter"},
        },
        xaxis_title="Tokens",
        yaxis_title="Attention Weight",
        xaxis_tickangle=45,
        template="plotly_white",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font={"family": "Inter"},
    )
    fig.update_xaxes(tickmode="array", tickvals=positions, ticktext=labels)
    return fig


def _render_results(news_text):
    """Run the model on *news_text* and render the prediction, charts and advice."""
    with st.spinner("🤖 Analyzing the news article..."):
        result = predict_news(news_text)

    col1, col2 = st.columns([1, 1], gap="large")
    with col1:
        st.markdown("### 📊 Prediction Result")
        if result["label"] == "FAKE":
            st.markdown(f'''
🔴 FAKE NEWS DETECTED
Confidence: {result["confidence"]:.2%}
''', unsafe_allow_html=True)
        else:
            st.markdown(f'''
🟢 AUTHENTIC NEWS
Confidence: {result["confidence"]:.2%}
''', unsafe_allow_html=True)
    with col2:
        st.markdown("### 📈 Confidence Breakdown")
        st.plotly_chart(plot_confidence(result["probabilities"]), use_container_width=True)

    st.markdown("### 🎯 Attention Analysis")
    st.markdown("""

The visualization below shows which words our AI model focused on while making its prediction. Darker colors indicate higher attention weights.

""", unsafe_allow_html=True)
    st.plotly_chart(
        plot_attention(news_text, result["attention_weights"], tokens=result["tokens"]),
        use_container_width=True,
    )

    st.markdown("### 🔍 Detailed Analysis")
    if result["label"] == "FAKE":
        st.markdown("""

⚠️ Fake News Indicators

💡 Recommendation: Verify this information through multiple reliable sources before sharing.

""", unsafe_allow_html=True)
    else:
        st.markdown("""

✅ Authentic News Indicators

💡 Note: While likely authentic, always cross-reference important news from multiple sources.

""", unsafe_allow_html=True)


def main():
    """Render the TrueCheck page: hero, features, input form and results."""
    # Hero Section
    st.markdown("""

🔍 TrueCheck

Advanced AI-powered fake news detection using cutting-edge deep learning technology. Get instant, accurate analysis of news articles with our hybrid BERT-BiLSTM model.

""", unsafe_allow_html=True)

    # Features Section
    st.markdown("""

Why Choose TrueCheck?

Our advanced AI model combines multiple technologies for superior accuracy

🤖

BERT Technology

Utilizes state-of-the-art BERT transformer for deep contextual understanding of news content

🧠

BiLSTM Processing

Bidirectional LSTM networks capture sequential patterns and dependencies in text structure

👁️

Attention Mechanism

Advanced attention layers provide interpretable insights into model decision-making process

""", unsafe_allow_html=True)

    # Main Content Section
    st.markdown("""

Analyze News Article

Paste any news article below and our AI will analyze it for authenticity. Get detailed insights including confidence scores and attention analysis.

""", unsafe_allow_html=True)

    # Input Section (centred in the middle column)
    _, center, _ = st.columns([1, 3, 1])
    with center:
        # A non-empty label is required for accessibility; it is hidden visually.
        news_text = st.text_area(
            "News article",
            height=200,
            placeholder="📰 Paste your news article here for analysis...",
            key="news_input",
            label_visibility="collapsed",
        )
        analyze_button = st.button("🔍 Analyze Article", key="analyze_button")

    if analyze_button:
        # Whitespace-only input is treated as empty.
        if news_text.strip():
            _render_results(news_text)
        else:
            st.markdown('''
⚠️ Please enter a news article to analyze
''', unsafe_allow_html=True)

    # Footer
    st.markdown(""" """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()