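"""Streamlit page for TrueCheck, a hybrid (BERT + BiLSTM + Attention) fake news
detector: it loads the trained model, runs a prediction on pasted text, and plots
confidence scores and attention weights."""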
import streamlit as st
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import sys
import plotly.express as px
import plotly.graph_objects as go
from transformers import BertTokenizer
import nltk
# Download required NLTK data
for resource_path, package in [
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/wordnet', 'wordnet'),
]:
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)
# Add project root to Python path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.models.hybrid_model import HybridFakeNewsDetector
from src.config.config import *
from src.data.preprocessor import TextPreprocessor
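# Page config is set in the main app.py entry point.
# The wildcard import above is expected to provide the constants used below:
# BERT_MODEL_NAME, LSTM_HIDDEN_SIZE, LSTM_NUM_LAYERS, DROPOUT_RATE,
# MAX_SEQUENCE_LENGTH and SAVED_MODELS_DIR.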
@st.cache_resource
def load_model_and_tokenizer():
    """Load the model and tokenizer (cached)."""
    # Initialize the hybrid model with the configured hyperparameters
    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE
    )
    # Load trained weights on CPU
    state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
    # Drop any keys the current architecture does not expect before loading
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer


@st.cache_resource
def get_preprocessor():
    """Get the text preprocessor (cached)."""
    return TextPreprocessor()
def predict_news(text):
    """Predict whether the given news article is fake or real."""
    # Get model, tokenizer, and preprocessor from the cache
    model, tokenizer = load_model_and_tokenizer()
    preprocessor = get_preprocessor()
    # Preprocess text
    processed_text = preprocessor.preprocess_text(text)
    # Tokenize
    encoding = tokenizer.encode_plus(
        processed_text,
        add_special_tokens=True,
        max_length=MAX_SEQUENCE_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    # Get prediction
    with torch.no_grad():
        outputs = model(
            encoding['input_ids'],
            encoding['attention_mask']
        )
        probabilities = torch.softmax(outputs['logits'], dim=1)
        prediction = torch.argmax(outputs['logits'], dim=1)
        attention_weights = outputs['attention_weights']
    # Convert attention weights to numpy and take the first (only) sequence
    attention_weights_np = attention_weights[0].cpu().numpy()
    return {
        'prediction': prediction.item(),
        'label': 'FAKE' if prediction.item() == 1 else 'REAL',
        'confidence': torch.max(probabilities, dim=1)[0].item(),
        'probabilities': {
            'REAL': probabilities[0][0].item(),
            'FAKE': probabilities[0][1].item()
        },
        'attention_weights': attention_weights_np
    }
def plot_confidence(probabilities):
    """Plot prediction confidence."""
    fig = go.Figure(data=[
        go.Bar(
            x=list(probabilities.keys()),
            y=list(probabilities.values()),
            text=[f'{p:.2%}' for p in probabilities.values()],
            textposition='auto',
            marker_color=['#4B5EAA', '#FF6B6B']
        )
    ])
    fig.update_layout(
        title='Prediction Confidence',
        xaxis_title='Class',
        yaxis_title='Probability',
        yaxis_range=[0, 1],
        template='plotly_white'
    )
    return fig
def plot_attention(text, attention_weights):
    """Plot attention weights."""
    # Note: the weights come from the model's attention over wordpiece tokens,
    # so aligning them to whitespace-split words is only an approximation.
    tokens = text.split()
    attention_weights = attention_weights[:len(tokens)]  # Truncate to match tokens
    # Ensure attention weights are a flat numpy array
    if isinstance(attention_weights, (list, np.ndarray)):
        attention_weights = np.array(attention_weights).flatten()
    # Format weights for display
    formatted_weights = [f'{float(w):.2f}' for w in attention_weights]
    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=attention_weights,
            text=formatted_weights,
            textposition='auto',
            marker_color='#4B5EAA'
        )
    ])
    fig.update_layout(
        title='Attention Weights',
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45,
        template='plotly_white'
    )
    return fig
def main():
    # Hero section
    st.markdown("""
        <div class="hero-section">
            <div style="display: flex; align-items: center; gap: 2rem;">
                <div style="flex: 1;">
                    <h1 style="font-size: 2.5rem; color: #333333;">TrueCheck</h1>
                    <p style="font-size: 1.2rem; color: #666666;">
                        Detect fake news with our advanced AI-powered system using BERT, BiLSTM, and Attention mechanisms.
                    </p>
                </div>
                <div style="flex: 1;">
                    <img src="https://img.freepik.com/free-vector/fake-news-concept-illustration_114360-3189.jpg" style="width: 100%; border-radius: 12px;" alt="Fake News Detection">
                </div>
            </div>
        </div>
    """, unsafe_allow_html=True)

    # Sidebar info
    st.sidebar.markdown("---")
    st.sidebar.header("About TrueCheck")
    st.sidebar.markdown("""
        <div style="font-size: 0.9rem; color: #666666;">
            <p>TrueCheck uses a hybrid deep learning model combining:</p>
            <ul>
                <li>BERT for contextual embeddings</li>
                <li>BiLSTM for sequence modeling</li>
                <li>Attention mechanism for interpretability</li>
            </ul>
        </div>
    """, unsafe_allow_html=True)

    # Main content
    st.header("Analyze News")
    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here..."
    )
if st.button("Analyze", key="analyze_button"): | |
if news_text: | |
with st.spinner("Analyzing the news article..."): | |
result = predict_news(news_text) | |
col1, col2 = st.columns([1, 1], gap="large") | |
with col1: | |
st.markdown("### Prediction") | |
if result['label'] == 'FAKE': | |
st.markdown(f'<div class="flash-message error-message">🔴 This news is likely FAKE (Confidence: {result["confidence"]:.2%})</div>', unsafe_allow_html=True) | |
else: | |
st.markdown(f'<div class="flash-message success-message">🟢 This news is likely REAL (Confidence: {result["confidence"]:.2%})</div>', unsafe_allow_html=True) | |
with col2: | |
st.markdown("### Confidence Scores") | |
st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True) | |
st.markdown("### Attention Analysis") | |
st.markdown(""" | |
<p style="color: #666666;"> | |
The attention weights show which parts of the text the model focused on while making its prediction. Higher weights indicate more important tokens. | |
</p> | |
""", unsafe_allow_html=True) | |
st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True) | |
st.markdown("### Model Explanation") | |
if result['label'] == 'FAKE': | |
st.markdown(""" | |
<div style="background-color: #F4F7FA; padding: 1rem; border-radius: 8px;"> | |
<p>The model identified this as fake news based on:</p> | |
<ul> | |
<li>Linguistic patterns typical of fake news</li> | |
<li>Inconsistencies in the content</li> | |
<li>Attention weights on suspicious phrases</li> | |
</ul> | |
</div> | |
""", unsafe_allow_html=True) | |
else: | |
st.markdown(""" | |
<div style="background-color: #F4F7FA; padding: 1rem; border-radius: 8px;"> | |
<p>The model identified this as real news based on:</p> | |
<ul> | |
<li>Credible language patterns</li> | |
<li>Consistent information</li> | |
<li>Attention weights on factual statements</li> | |
</ul> | |
</div> | |
""", unsafe_allow_html=True) | |
else: | |
st.markdown('<div class="flash-message error-message">Please enter a news article to analyze.</div>', unsafe_allow_html=True) | |
if __name__ == "__main__": | |
main() |
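# Note: page config lives in the main app.py, so the app is normally started from
# there (e.g. `streamlit run app.py`, assuming that entry-point name). Running this
# page script directly with `streamlit run <path-to-this-file>` should also work,
# since Streamlit executes the script with __name__ set to "__main__".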