Files changed (1) hide show
  1. src/app.py +317 -84
src/app.py CHANGED
@@ -1,3 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  import pandas as pd
@@ -35,33 +278,21 @@ from src.models.hybrid_model import HybridFakeNewsDetector
35
  from src.config.config import *
36
  from src.data.preprocessor import TextPreprocessor
37
 
38
- # Page config is set in main app.py
39
-
40
  @st.cache_resource
41
  def load_model_and_tokenizer():
42
  """Load the model and tokenizer (cached)."""
43
- # Initialize model
44
  model = HybridFakeNewsDetector(
45
  bert_model_name=BERT_MODEL_NAME,
46
  lstm_hidden_size=LSTM_HIDDEN_SIZE,
47
  lstm_num_layers=LSTM_NUM_LAYERS,
48
  dropout_rate=DROPOUT_RATE
49
  )
50
-
51
- # Load trained weights
52
  state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
53
-
54
- # Filter out unexpected keys
55
  model_state_dict = model.state_dict()
56
  filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
57
-
58
- # Load the filtered state dict
59
  model.load_state_dict(filtered_state_dict, strict=False)
60
  model.eval()
61
-
62
- # Initialize tokenizer
63
  tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
64
-
65
  return model, tokenizer
66
 
67
  @st.cache_resource
@@ -71,14 +302,9 @@ def get_preprocessor():
71
 
72
  def predict_news(text):
73
  """Predict if the given news is fake or real."""
74
- # Get model, tokenizer, and preprocessor from cache
75
  model, tokenizer = load_model_and_tokenizer()
76
  preprocessor = get_preprocessor()
77
-
78
- # Preprocess text
79
  processed_text = preprocessor.preprocess_text(text)
80
-
81
- # Tokenize
82
  encoding = tokenizer.encode_plus(
83
  processed_text,
84
  add_special_tokens=True,
@@ -88,8 +314,6 @@ def predict_news(text):
88
  return_attention_mask=True,
89
  return_tensors='pt'
90
  )
91
-
92
- # Get prediction
93
  with torch.no_grad():
94
  outputs = model(
95
  encoding['input_ids'],
@@ -98,10 +322,7 @@ def predict_news(text):
98
  probabilities = torch.softmax(outputs['logits'], dim=1)
99
  prediction = torch.argmax(outputs['logits'], dim=1)
100
  attention_weights = outputs['attention_weights']
101
-
102
- # Convert attention weights to numpy and get the first sequence
103
  attention_weights_np = attention_weights[0].cpu().numpy()
104
-
105
  return {
106
  'prediction': prediction.item(),
107
  'label': 'FAKE' if prediction.item() == 1 else 'REAL',
@@ -121,121 +342,133 @@ def plot_confidence(probabilities):
121
  y=list(probabilities.values()),
122
  text=[f'{p:.2%}' for p in probabilities.values()],
123
  textposition='auto',
 
124
  )
125
  ])
126
-
127
  fig.update_layout(
128
  title='Prediction Confidence',
129
  xaxis_title='Class',
130
  yaxis_title='Probability',
131
- yaxis_range=[0, 1]
 
132
  )
133
-
134
  return fig
135
 
136
  def plot_attention(text, attention_weights):
137
  """Plot attention weights."""
138
  tokens = text.split()
139
- attention_weights = attention_weights[:len(tokens)] # Truncate to match tokens
140
-
141
- # Ensure attention weights are in the correct format
142
  if isinstance(attention_weights, (list, np.ndarray)):
143
  attention_weights = np.array(attention_weights).flatten()
144
-
145
- # Format weights for display
146
  formatted_weights = [f'{float(w):.2f}' for w in attention_weights]
147
-
148
  fig = go.Figure(data=[
149
  go.Bar(
150
  x=tokens,
151
  y=attention_weights,
152
  text=formatted_weights,
153
  textposition='auto',
 
154
  )
155
  ])
156
-
157
  fig.update_layout(
158
  title='Attention Weights',
159
  xaxis_title='Tokens',
160
  yaxis_title='Attention Weight',
161
- xaxis_tickangle=45
 
162
  )
163
-
164
  return fig
165
 
166
  def main():
167
- st.title("📰 Fake News Detection System")
168
- st.write("""
169
- This application uses a hybrid deep learning model (BERT + BiLSTM + Attention)
170
- to detect fake news articles. Enter a news article below to analyze it.
171
- """)
172
-
173
- # Sidebar
174
- st.sidebar.title("About")
175
- st.sidebar.info("""
176
-
177
- The model combines:
178
- - BERT for contextual embeddings
179
- - BiLSTM for sequence modeling
180
- - Attention mechanism for interpretability
181
- """)
182
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # Main content
184
- st.header("News Analysis")
185
-
186
- # Text input
187
  news_text = st.text_area(
188
  "Enter the news article to analyze:",
189
  height=200,
190
  placeholder="Paste your news article here..."
191
  )
192
 
193
- if st.button("Analyze"):
194
  if news_text:
195
  with st.spinner("Analyzing the news article..."):
196
- # Get prediction
197
  result = predict_news(news_text)
198
-
199
- # Display result
200
- col1, col2 = st.columns(2)
201
 
202
  with col1:
203
- st.subheader("Prediction")
204
  if result['label'] == 'FAKE':
205
- st.error(f"🔴 This news is likely FAKE (Confidence: {result['confidence']:.2%})")
206
  else:
207
- st.success(f"🟢 This news is likely REAL (Confidence: {result['confidence']:.2%})")
208
 
209
  with col2:
210
- st.subheader("Confidence Scores")
211
  st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
212
 
213
- # Show attention visualization
214
- st.subheader("Attention Analysis")
215
- st.write("""
216
- The attention weights show which parts of the text the model focused on
217
- while making its prediction. Higher weights indicate more important tokens.
218
- """)
219
  st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
220
 
221
- # Show model explanation
222
- st.subheader("Model Explanation")
223
  if result['label'] == 'FAKE':
224
- st.write("""
225
- The model identified this as fake news based on:
226
- - Linguistic patterns typical of fake news
227
- - Inconsistencies in the content
228
- - Attention weights on suspicious phrases
229
- """)
 
 
 
 
230
  else:
231
- st.write("""
232
- The model identified this as real news based on:
233
- - Credible language patterns
234
- - Consistent information
235
- - Attention weights on factual statements
236
- """)
 
 
 
 
237
  else:
238
- st.warning("Please enter a news article to analyze.")
239
 
240
  if __name__ == "__main__":
241
- main()
 
1
+ # import streamlit as st
2
+ # import torch
3
+ # import pandas as pd
4
+ # import numpy as np
5
+ # from pathlib import Path
6
+ # import sys
7
+ # import plotly.express as px
8
+ # import plotly.graph_objects as go
9
+ # from transformers import BertTokenizer
10
+ # import nltk
11
+
12
+ # # Download required NLTK data
13
+ # try:
14
+ # nltk.data.find('tokenizers/punkt')
15
+ # except LookupError:
16
+ # nltk.download('punkt')
17
+ # try:
18
+ # nltk.data.find('corpora/stopwords')
19
+ # except LookupError:
20
+ # nltk.download('stopwords')
21
+ # try:
22
+ # nltk.data.find('tokenizers/punkt_tab')
23
+ # except LookupError:
24
+ # nltk.download('punkt_tab')
25
+ # try:
26
+ # nltk.data.find('corpora/wordnet')
27
+ # except LookupError:
28
+ # nltk.download('wordnet')
29
+
30
+ # # Add project root to Python path
31
+ # project_root = Path(__file__).parent.parent
32
+ # sys.path.append(str(project_root))
33
+
34
+ # from src.models.hybrid_model import HybridFakeNewsDetector
35
+ # from src.config.config import *
36
+ # from src.data.preprocessor import TextPreprocessor
37
+
38
+ # # Page config is set in main app.py
39
+
40
+ # @st.cache_resource
41
+ # def load_model_and_tokenizer():
42
+ # """Load the model and tokenizer (cached)."""
43
+ # # Initialize model
44
+ # model = HybridFakeNewsDetector(
45
+ # bert_model_name=BERT_MODEL_NAME,
46
+ # lstm_hidden_size=LSTM_HIDDEN_SIZE,
47
+ # lstm_num_layers=LSTM_NUM_LAYERS,
48
+ # dropout_rate=DROPOUT_RATE
49
+ # )
50
+
51
+ # # Load trained weights
52
+ # state_dict = torch.load(SAVED_MODELS_DIR / "final_model.pt", map_location=torch.device('cpu'))
53
+
54
+ # # Filter out unexpected keys
55
+ # model_state_dict = model.state_dict()
56
+ # filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
57
+
58
+ # # Load the filtered state dict
59
+ # model.load_state_dict(filtered_state_dict, strict=False)
60
+ # model.eval()
61
+
62
+ # # Initialize tokenizer
63
+ # tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
64
+
65
+ # return model, tokenizer
66
+
67
+ # @st.cache_resource
68
+ # def get_preprocessor():
69
+ # """Get the text preprocessor (cached)."""
70
+ # return TextPreprocessor()
71
+
72
+ # def predict_news(text):
73
+ # """Predict if the given news is fake or real."""
74
+ # # Get model, tokenizer, and preprocessor from cache
75
+ # model, tokenizer = load_model_and_tokenizer()
76
+ # preprocessor = get_preprocessor()
77
+
78
+ # # Preprocess text
79
+ # processed_text = preprocessor.preprocess_text(text)
80
+
81
+ # # Tokenize
82
+ # encoding = tokenizer.encode_plus(
83
+ # processed_text,
84
+ # add_special_tokens=True,
85
+ # max_length=MAX_SEQUENCE_LENGTH,
86
+ # padding='max_length',
87
+ # truncation=True,
88
+ # return_attention_mask=True,
89
+ # return_tensors='pt'
90
+ # )
91
+
92
+ # # Get prediction
93
+ # with torch.no_grad():
94
+ # outputs = model(
95
+ # encoding['input_ids'],
96
+ # encoding['attention_mask']
97
+ # )
98
+ # probabilities = torch.softmax(outputs['logits'], dim=1)
99
+ # prediction = torch.argmax(outputs['logits'], dim=1)
100
+ # attention_weights = outputs['attention_weights']
101
+
102
+ # # Convert attention weights to numpy and get the first sequence
103
+ # attention_weights_np = attention_weights[0].cpu().numpy()
104
+
105
+ # return {
106
+ # 'prediction': prediction.item(),
107
+ # 'label': 'FAKE' if prediction.item() == 1 else 'REAL',
108
+ # 'confidence': torch.max(probabilities, dim=1)[0].item(),
109
+ # 'probabilities': {
110
+ # 'REAL': probabilities[0][0].item(),
111
+ # 'FAKE': probabilities[0][1].item()
112
+ # },
113
+ # 'attention_weights': attention_weights_np
114
+ # }
115
+
116
+ # def plot_confidence(probabilities):
117
+ # """Plot prediction confidence."""
118
+ # fig = go.Figure(data=[
119
+ # go.Bar(
120
+ # x=list(probabilities.keys()),
121
+ # y=list(probabilities.values()),
122
+ # text=[f'{p:.2%}' for p in probabilities.values()],
123
+ # textposition='auto',
124
+ # )
125
+ # ])
126
+
127
+ # fig.update_layout(
128
+ # title='Prediction Confidence',
129
+ # xaxis_title='Class',
130
+ # yaxis_title='Probability',
131
+ # yaxis_range=[0, 1]
132
+ # )
133
+
134
+ # return fig
135
+
136
+ # def plot_attention(text, attention_weights):
137
+ # """Plot attention weights."""
138
+ # tokens = text.split()
139
+ # attention_weights = attention_weights[:len(tokens)] # Truncate to match tokens
140
+
141
+ # # Ensure attention weights are in the correct format
142
+ # if isinstance(attention_weights, (list, np.ndarray)):
143
+ # attention_weights = np.array(attention_weights).flatten()
144
+
145
+ # # Format weights for display
146
+ # formatted_weights = [f'{float(w):.2f}' for w in attention_weights]
147
+
148
+ # fig = go.Figure(data=[
149
+ # go.Bar(
150
+ # x=tokens,
151
+ # y=attention_weights,
152
+ # text=formatted_weights,
153
+ # textposition='auto',
154
+ # )
155
+ # ])
156
+
157
+ # fig.update_layout(
158
+ # title='Attention Weights',
159
+ # xaxis_title='Tokens',
160
+ # yaxis_title='Attention Weight',
161
+ # xaxis_tickangle=45
162
+ # )
163
+
164
+ # return fig
165
+
166
+ # def main():
167
+ # st.title("📰 Fake News Detection System")
168
+ # st.write("""
169
+ # This application uses a hybrid deep learning model (BERT + BiLSTM + Attention)
170
+ # to detect fake news articles. Enter a news article below to analyze it.
171
+ # """)
172
+
173
+ # # Sidebar
174
+ # st.sidebar.title("About")
175
+ # st.sidebar.info("""
176
+
177
+ # The model combines:
178
+ # - BERT for contextual embeddings
179
+ # - BiLSTM for sequence modeling
180
+ # - Attention mechanism for interpretability
181
+ # """)
182
+
183
+ # # Main content
184
+ # st.header("News Analysis")
185
+
186
+ # # Text input
187
+ # news_text = st.text_area(
188
+ # "Enter the news article to analyze:",
189
+ # height=200,
190
+ # placeholder="Paste your news article here..."
191
+ # )
192
+
193
+ # if st.button("Analyze"):
194
+ # if news_text:
195
+ # with st.spinner("Analyzing the news article..."):
196
+ # # Get prediction
197
+ # result = predict_news(news_text)
198
+
199
+ # # Display result
200
+ # col1, col2 = st.columns(2)
201
+
202
+ # with col1:
203
+ # st.subheader("Prediction")
204
+ # if result['label'] == 'FAKE':
205
+ # st.error(f"🔴 This news is likely FAKE (Confidence: {result['confidence']:.2%})")
206
+ # else:
207
+ # st.success(f"🟢 This news is likely REAL (Confidence: {result['confidence']:.2%})")
208
+
209
+ # with col2:
210
+ # st.subheader("Confidence Scores")
211
+ # st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)
212
+
213
+ # # Show attention visualization
214
+ # st.subheader("Attention Analysis")
215
+ # st.write("""
216
+ # The attention weights show which parts of the text the model focused on
217
+ # while making its prediction. Higher weights indicate more important tokens.
218
+ # """)
219
+ # st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)
220
+
221
+ # # Show model explanation
222
+ # st.subheader("Model Explanation")
223
+ # if result['label'] == 'FAKE':
224
+ # st.write("""
225
+ # The model identified this as fake news based on:
226
+ # - Linguistic patterns typical of fake news
227
+ # - Inconsistencies in the content
228
+ # - Attention weights on suspicious phrases
229
+ # """)
230
+ # else:
231
+ # st.write("""
232
+ # The model identified this as real news based on:
233
+ # - Credible language patterns
234
+ # - Consistent information
235
+ # - Attention weights on factual statements
236
+ # """)
237
+ # else:
238
+ # st.warning("Please enter a news article to analyze.")
239
+
240
+ # if __name__ == "__main__":
241
+ # main()
242
+
243
+
244
  import streamlit as st
245
  import torch
246
  import pandas as pd
 
278
  from src.config.config import *
279
  from src.data.preprocessor import TextPreprocessor
280
 
 
 
281
@st.cache_resource
def load_model_and_tokenizer():
    """Build the hybrid detector, load its trained weights, and return it with its tokenizer.

    The result is cached by ``st.cache_resource``, so the checkpoint is read
    only on the first call.

    Checkpoint entries whose names are not part of the freshly built model
    are dropped before loading, and loading is non-strict, so a partially
    matching checkpoint still loads. Any dropped or missing parameter names
    are logged as warnings so that a mismatched checkpoint (which would leave
    layers randomly initialised) does not go unnoticed.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a
        ``HybridFakeNewsDetector`` in eval mode with weights mapped to CPU and
        ``tokenizer`` is the ``BertTokenizer`` for ``BERT_MODEL_NAME``.
    """
    import logging

    logger = logging.getLogger(__name__)

    model = HybridFakeNewsDetector(
        bert_model_name=BERT_MODEL_NAME,
        lstm_hidden_size=LSTM_HIDDEN_SIZE,
        lstm_num_layers=LSTM_NUM_LAYERS,
        dropout_rate=DROPOUT_RATE,
    )

    checkpoint_path = SAVED_MODELS_DIR / "final_model.pt"
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source. Consider passing weights_only=True if the
    # installed torch supports it and the file is a plain state dict - confirm.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Keep only the entries the current architecture knows about.
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}

    dropped_keys = sorted(set(state_dict) - set(model_state_dict))
    if dropped_keys:
        logger.warning(
            "Ignoring %d checkpoint entries unknown to the model: %s",
            len(dropped_keys),
            dropped_keys,
        )

    # strict=False tolerates missing parameters; report them instead of
    # silently serving predictions from untrained layers.
    load_result = model.load_state_dict(filtered_state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning(
            "Checkpoint %s has no weights for %d model parameters: %s",
            checkpoint_path,
            len(load_result.missing_keys),
            list(load_result.missing_keys),
        )

    model.eval()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
    return model, tokenizer
297
 
298
  @st.cache_resource
 
302
 
303
  def predict_news(text):
304
  """Predict if the given news is fake or real."""
 
305
  model, tokenizer = load_model_and_tokenizer()
306
  preprocessor = get_preprocessor()
 
 
307
  processed_text = preprocessor.preprocess_text(text)
 
 
308
  encoding = tokenizer.encode_plus(
309
  processed_text,
310
  add_special_tokens=True,
 
314
  return_attention_mask=True,
315
  return_tensors='pt'
316
  )
 
 
317
  with torch.no_grad():
318
  outputs = model(
319
  encoding['input_ids'],
 
322
  probabilities = torch.softmax(outputs['logits'], dim=1)
323
  prediction = torch.argmax(outputs['logits'], dim=1)
324
  attention_weights = outputs['attention_weights']
 
 
325
  attention_weights_np = attention_weights[0].cpu().numpy()
 
326
  return {
327
  'prediction': prediction.item(),
328
  'label': 'FAKE' if prediction.item() == 1 else 'REAL',
 
342
  y=list(probabilities.values()),
343
  text=[f'{p:.2%}' for p in probabilities.values()],
344
  textposition='auto',
345
+ marker_color=['#4B5EAA', '#FF6B6B']
346
  )
347
  ])
 
348
  fig.update_layout(
349
  title='Prediction Confidence',
350
  xaxis_title='Class',
351
  yaxis_title='Probability',
352
+ yaxis_range=[0, 1],
353
+ template='plotly_white'
354
  )
 
355
  return fig
356
 
357
def plot_attention(text, attention_weights):
    """Plot attention weights as a bar chart with one bar per whitespace token.

    The tokens are ``text.split()`` and are paired with the weights purely by
    position. Both sequences are trimmed to their common length so the x and
    y data always have the same number of entries.

    NOTE(review): ``main`` passes the raw article text, while ``predict_news``
    obtains the weights from the preprocessed, BERT-tokenized text, so the
    token/weight pairing is presumably only approximate - confirm.

    Args:
        text: Text whose whitespace-separated tokens label the x axis.
        attention_weights: Array-like of numeric weights; any shape is
            accepted and flattened to one dimension.

    Returns:
        A plotly ``go.Figure`` containing a single bar trace.
    """
    tokens = text.split()

    # Flatten before truncating so the cut always happens along the
    # per-position axis, whatever shape the model returned.
    weights = np.asarray(attention_weights, dtype=float).reshape(-1)

    # A bar chart needs exactly one weight per label: trim both sides to the
    # shorter of the two (the text may have more tokens than weights).
    shared_length = min(len(tokens), weights.size)
    tokens = tokens[:shared_length]
    weights = weights[:shared_length]

    formatted_weights = [f'{w:.2f}' for w in weights]

    fig = go.Figure(data=[
        go.Bar(
            x=tokens,
            y=weights,
            text=formatted_weights,
            textposition='auto',
            marker_color='#4B5EAA',
        )
    ])
    fig.update_layout(
        title='Attention Weights',
        xaxis_title='Tokens',
        yaxis_title='Attention Weight',
        xaxis_tickangle=45,
        template='plotly_white',
    )
    return fig
381
 
382
def main():
    """Render the TrueCheck page and run the analysis when requested.

    Draws the hero banner and the sidebar description, then a text area and
    an "Analyze" button. When the button is pressed with non-blank text, the
    article is passed to ``predict_news`` and the page shows the predicted
    label with its confidence, the class-probability chart, the attention
    chart and a short static explanation. Blank or whitespace-only input
    produces an error message instead of a prediction.
    """
    # Hero section
    st.markdown("""
    <div class="hero-section">
      <div style="display: flex; align-items: center; gap: 2rem;">
        <div style="flex: 1;">
          <h1 style="font-size: 2.5rem; color: #333333;">TrueCheck</h1>
          <p style="font-size: 1.2rem; color: #666666;">
            Detect fake news with our advanced AI-powered system using BERT, BiLSTM, and Attention mechanisms.
          </p>
        </div>
        <div style="flex: 1;">
          <img src="https://img.freepik.com/free-vector/fake-news-concept-illustration_114360-3189.jpg" style="width: 100%; border-radius: 12px;" alt="Fake News Detection">
        </div>
      </div>
    </div>
    """, unsafe_allow_html=True)

    # Sidebar info
    st.sidebar.markdown("---")
    st.sidebar.header("About TrueCheck")
    st.sidebar.markdown("""
    <div style="font-size: 0.9rem; color: #666666;">
      <p>TrueCheck uses a hybrid deep learning model combining:</p>
      <ul>
        <li>BERT for contextual embeddings</li>
        <li>BiLSTM for sequence modeling</li>
        <li>Attention mechanism for interpretability</li>
      </ul>
    </div>
    """, unsafe_allow_html=True)

    # Main content
    st.header("Analyze News")
    news_text = st.text_area(
        "Enter the news article to analyze:",
        height=200,
        placeholder="Paste your news article here..."
    )

    # Nothing more to render until the user asks for an analysis.
    if not st.button("Analyze", key="analyze_button"):
        return

    # Whitespace-only input is treated like empty input: there is nothing
    # for the model to classify.
    if not (news_text and news_text.strip()):
        st.markdown('<div class="flash-message error-message">Please enter a news article to analyze.</div>', unsafe_allow_html=True)
        return

    with st.spinner("Analyzing the news article..."):
        result = predict_news(news_text)
        is_fake = result['label'] == 'FAKE'
        col1, col2 = st.columns([1, 1], gap="large")

        with col1:
            st.markdown("### Prediction")
            if is_fake:
                st.markdown(f'<div class="flash-message error-message">🔴 This news is likely FAKE (Confidence: {result["confidence"]:.2%})</div>', unsafe_allow_html=True)
            else:
                st.markdown(f'<div class="flash-message success-message">🟢 This news is likely REAL (Confidence: {result["confidence"]:.2%})</div>', unsafe_allow_html=True)

        with col2:
            st.markdown("### Confidence Scores")
            st.plotly_chart(plot_confidence(result['probabilities']), use_container_width=True)

        st.markdown("### Attention Analysis")
        st.markdown("""
        <p style="color: #666666;">
          The attention weights show which parts of the text the model focused on while making its prediction. Higher weights indicate more important tokens.
        </p>
        """, unsafe_allow_html=True)
        st.plotly_chart(plot_attention(news_text, result['attention_weights']), use_container_width=True)

        # The explanation text is static per label; it is not derived from
        # the attention weights of this particular article.
        st.markdown("### Model Explanation")
        if is_fake:
            st.markdown("""
            <div style="background-color: #F4F7FA; padding: 1rem; border-radius: 8px;">
              <p>The model identified this as fake news based on:</p>
              <ul>
                <li>Linguistic patterns typical of fake news</li>
                <li>Inconsistencies in the content</li>
                <li>Attention weights on suspicious phrases</li>
              </ul>
            </div>
            """, unsafe_allow_html=True)
        else:
            st.markdown("""
            <div style="background-color: #F4F7FA; padding: 1rem; border-radius: 8px;">
              <p>The model identified this as real news based on:</p>
              <ul>
                <li>Credible language patterns</li>
                <li>Consistent information</li>
                <li>Attention weights on factual statements</li>
              </ul>
            </div>
            """, unsafe_allow_html=True)
472
 
473
  if __name__ == "__main__":
474
+ main()