Krish Patel committed on
Commit
4adafc2
·
1 Parent(s): 8c24dde

Separated knowledge graph and Model responses

Browse files
Files changed (3) hide show
  1. app.py +62 -0
  2. final.py +7 -91
  3. nexus-frontend/src/pages/Dashboard.tsx +54 -3
app.py CHANGED
@@ -4,6 +4,9 @@ from pydantic import BaseModel
4
  from final import predict_news, get_gemini_analysis
5
  import os
6
  from tempfile import NamedTemporaryFile
 
 
 
7
 
8
  app = FastAPI()
9
 
@@ -81,3 +84,62 @@ async def detect_deepfake(file: UploadFile = File(...)):
81
 
82
  except Exception as e:
83
  return {"error": str(e)}, 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from final import predict_news, get_gemini_analysis
5
  import os
6
  from tempfile import NamedTemporaryFile
7
+ from knowledge_graph_generator import KnowledgeGraphBuilder
8
+ import networkx as nx
9
+ import plotly.graph_objects as go
10
 
11
  app = FastAPI()
12
 
 
84
 
85
  except Exception as e:
86
  return {"error": str(e)}, 500
87
+
88
@app.post("/generate-knowledge-graph")
async def generate_knowledge_graph(news: NewsInput):
    """Render a knowledge graph for ``news.text`` as a Plotly HTML document.

    The text is classified with ``predict_news``. A fresh
    ``KnowledgeGraphBuilder`` is then updated with the text and the negated
    "FAKE" verdict, and its graph is positioned with a spring layout.
    Edges are drawn red when the verdict is "FAKE" and green otherwise.

    Args:
        news: Request body; only ``news.text`` is read here.

    Returns:
        The string produced by ``fig.to_html()``.
    """
    kg_builder = KnowledgeGraphBuilder()
    is_fake = predict_news(news.text) == "FAKE"
    # NOTE(review): the second argument presumably means "is real" --
    # confirm against KnowledgeGraphBuilder.update_knowledge_graph.
    kg_builder.update_knowledge_graph(news.text, not is_fake)

    graph = kg_builder.knowledge_graph
    pos = nx.spring_layout(graph)

    # Collect coordinates in plain lists first. Growing the trace properties
    # with ``trace['x'] += ...`` rebuilds the whole tuple on every step.
    edge_x, edge_y = [], []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None separates consecutive edges in the line trace.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    node_x, node_y, node_text = [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    # Edge colour encodes the verdict; rgba is used for transparency.
    edge_color = 'rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=2, color=edge_color),
        hoverinfo='none',
        mode='lines',
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black'),
        ),
        text=node_text,
    )

    # Transparent backgrounds and hidden axes leave only the graph visible.
    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False)
    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=hidden_axis,
            yaxis=dict(hidden_axis),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        ),
    )

    return fig.to_html()
final.py CHANGED
@@ -1,9 +1,6 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- import networkx as nx
4
  import spacy
5
- import pickle
6
- import pandas as pd
7
  import google.generativeai as genai
8
  import json
9
  import os
@@ -20,23 +17,10 @@ tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
20
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
21
  model.eval()
22
 
23
- #########################
24
  def setup_gemini():
25
  genai.configure(api_key=os.getenv("GEMINI_API"))
26
  model = genai.GenerativeModel('gemini-pro')
27
  return model
28
- #########################
29
-
30
- # Load the knowledge graph
31
- graph_path = "./models/knowledge_graph.pkl" # Replace with the actual path to your knowledge graph
32
- with open(graph_path, 'rb') as f:
33
- graph_data = pickle.load(f)
34
-
35
- knowledge_graph = nx.DiGraph()
36
- knowledge_graph.add_nodes_from(graph_data['nodes'].items())
37
- for u, edges in graph_data['edges'].items():
38
- for v, data in edges.items():
39
- knowledge_graph.add_edge(u, v, **data)
40
 
41
  def predict_with_model(text):
42
  """Predict whether the news is real or fake using the ML model."""
@@ -47,76 +31,17 @@ def predict_with_model(text):
47
  predicted_label = torch.argmax(probabilities, dim=-1).item()
48
  return "FAKE" if predicted_label == 1 else "REAL"
49
 
50
- def update_knowledge_graph(text, is_real):
51
- """Update the knowledge graph with the new article."""
52
- entities = extract_entities(text)
53
- for entity, entity_type in entities:
54
- if not knowledge_graph.has_node(entity):
55
- knowledge_graph.add_node(
56
- entity,
57
- type=entity_type,
58
- real_count=1 if is_real else 0,
59
- fake_count=0 if is_real else 1
60
- )
61
- else:
62
- if is_real:
63
- knowledge_graph.nodes[entity]['real_count'] += 1
64
- else:
65
- knowledge_graph.nodes[entity]['fake_count'] += 1
66
-
67
- for i, (entity1, _) in enumerate(entities):
68
- for entity2, _ in entities[i+1:]:
69
- if not knowledge_graph.has_edge(entity1, entity2):
70
- knowledge_graph.add_edge(
71
- entity1,
72
- entity2,
73
- weight=1,
74
- is_real=is_real
75
- )
76
- else:
77
- knowledge_graph[entity1][entity2]['weight'] += 1
78
-
79
  def extract_entities(text):
80
  """Extract named entities from text using spaCy."""
81
  doc = nlp(text)
82
  entities = [(ent.text, ent.label_) for ent in doc.ents]
83
  return entities
84
 
85
- def predict_with_knowledge_graph(text):
86
- """Predict whether the news is real or fake using the knowledge graph."""
87
- entities = extract_entities(text)
88
- real_score = 0
89
- fake_score = 0
90
-
91
- for entity, _ in entities:
92
- if knowledge_graph.has_node(entity):
93
- real_count = knowledge_graph.nodes[entity].get('real_count', 0)
94
- fake_count = knowledge_graph.nodes[entity].get('fake_count', 0)
95
- total = real_count + fake_count
96
- if total > 0:
97
- real_score += real_count / total
98
- fake_score += fake_count / total
99
-
100
- if real_score > fake_score:
101
- return "REAL"
102
- else:
103
- return "FAKE"
104
-
105
  def predict_news(text):
106
- """Predict whether the news is real or fake using both the ML model and the knowledge graph."""
107
  # Predict with the ML model
108
- ml_prediction = predict_with_model(text)
109
- is_real = ml_prediction == "REAL"
110
-
111
- # Update the knowledge graph
112
- update_knowledge_graph(text, is_real)
113
-
114
- # Predict with the knowledge graph
115
- kg_prediction = predict_with_knowledge_graph(text)
116
-
117
- # Combine predictions (for simplicity, we use the ML model's prediction here)
118
- # You can enhance this by combining the scores from both predictions
119
- return ml_prediction if ml_prediction == kg_prediction else "UNCERTAIN"
120
 
121
  def analyze_content_gemini(model, text):
122
  prompt = f"""Analyze this news text and return a JSON object with the following structure:
@@ -164,16 +89,12 @@ def analyze_content_gemini(model, text):
164
  Analyze this text and return only the JSON response: {text}"""
165
 
166
  response = model.generate_content(prompt)
167
- # return json.loads(response.text)
168
- # Add error handling and response cleaning
169
  try:
170
- # Clean the response text to ensure it's valid JSON
171
  cleaned_text = response.text.strip()
172
  if cleaned_text.startswith('```json'):
173
- cleaned_text = cleaned_text[7:-3] # Remove ```json and ``` markers
174
  return json.loads(cleaned_text)
175
  except json.JSONDecodeError:
176
- # Return a default structured response if JSON parsing fails
177
  return {
178
  "gemini_analysis": {
179
  "predicted_classification": "UNCERTAIN",
@@ -182,7 +103,6 @@ def analyze_content_gemini(model, text):
182
  }
183
  }
184
 
185
-
186
  def clean_gemini_output(text):
187
  """Remove markdown formatting from Gemini output"""
188
  text = text.replace('##', '')
@@ -193,10 +113,7 @@ def get_gemini_analysis(text):
193
  """Get detailed content analysis from Gemini."""
194
  gemini_model = setup_gemini()
195
  gemini_analysis = analyze_content_gemini(gemini_model, text)
196
- # cleaned_analysis = clean_gemini_output(gemini_analysis)
197
- # return cleaned_analysis
198
  return gemini_analysis
199
- #########################
200
 
201
  def main():
202
  print("Welcome to the News Classifier!")
@@ -209,15 +126,14 @@ def main():
209
  print("Thank you for using the News Classifier!")
210
  return
211
 
212
- # First get ML and Knowledge Graph prediction
213
  prediction = predict_news(news_text)
214
- print(f"\nML and Knowledge Graph Analysis: {prediction}")
215
 
216
- # Then get Gemini analysis
217
  print("\n=== Detailed Gemini Analysis ===")
218
  gemini_result = get_gemini_analysis(news_text)
219
  print(gemini_result)
220
 
221
-
222
  if __name__ == "__main__":
223
  main()
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
3
  import spacy
 
 
4
  import google.generativeai as genai
5
  import json
6
  import os
 
17
  model = AutoModelForSequenceClassification.from_pretrained(model_path)
18
  model.eval()
19
 
 
20
  def setup_gemini():
21
  genai.configure(api_key=os.getenv("GEMINI_API"))
22
  model = genai.GenerativeModel('gemini-pro')
23
  return model
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def predict_with_model(text):
26
  """Predict whether the news is real or fake using the ML model."""
 
31
  predicted_label = torch.argmax(probabilities, dim=-1).item()
32
  return "FAKE" if predicted_label == 1 else "REAL"
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def extract_entities(text):
35
  """Extract named entities from text using spaCy."""
36
  doc = nlp(text)
37
  entities = [(ent.text, ent.label_) for ent in doc.ents]
38
  return entities
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def predict_news(text):
41
+ """Predict whether the news is real or fake using the ML model."""
42
  # Predict with the ML model
43
+ prediction = predict_with_model(text)
44
+ return prediction
 
 
 
 
 
 
 
 
 
 
45
 
46
  def analyze_content_gemini(model, text):
47
  prompt = f"""Analyze this news text and return a JSON object with the following structure:
 
89
  Analyze this text and return only the JSON response: {text}"""
90
 
91
  response = model.generate_content(prompt)
 
 
92
  try:
 
93
  cleaned_text = response.text.strip()
94
  if cleaned_text.startswith('```json'):
95
+ cleaned_text = cleaned_text[7:-3]
96
  return json.loads(cleaned_text)
97
  except json.JSONDecodeError:
 
98
  return {
99
  "gemini_analysis": {
100
  "predicted_classification": "UNCERTAIN",
 
103
  }
104
  }
105
 
 
106
  def clean_gemini_output(text):
107
  """Remove markdown formatting from Gemini output"""
108
  text = text.replace('##', '')
 
113
  """Get detailed content analysis from Gemini."""
114
  gemini_model = setup_gemini()
115
  gemini_analysis = analyze_content_gemini(gemini_model, text)
 
 
116
  return gemini_analysis
 
117
 
118
  def main():
119
  print("Welcome to the News Classifier!")
 
126
  print("Thank you for using the News Classifier!")
127
  return
128
 
129
+ # Get ML prediction
130
  prediction = predict_news(news_text)
131
+ print(f"\nML Analysis: {prediction}")
132
 
133
+ # Get Gemini analysis
134
  print("\n=== Detailed Gemini Analysis ===")
135
  gemini_result = get_gemini_analysis(news_text)
136
  print(gemini_result)
137
 
 
138
  if __name__ == "__main__":
139
  main()
nexus-frontend/src/pages/Dashboard.tsx CHANGED
@@ -110,6 +110,7 @@ const AnimatedBackground = memo(() => {
110
  });
111
 
112
  const Dashboard = () => {
 
113
  const [inputText, setInputText] = useState('');
114
  const [isProcessing, setIsProcessing] = useState(false);
115
  const [result, setResult] = useState<{
@@ -184,6 +185,24 @@ const Dashboard = () => {
184
  }
185
  };
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  return (
188
  <div className="min-h-screen relative overflow-hidden">
189
  <AnimatedBackground />
@@ -203,13 +222,30 @@ const Dashboard = () => {
203
  className="w-full h-40 p-4 rounded-xl bg-gray-700 text-white placeholder-gray-400 focus:outline-none"
204
  placeholder="Enter text to analyze..."
205
  />
206
- <button
207
  onClick={processInput}
208
  disabled={isProcessing || !inputText.trim()}
209
  className="absolute bottom-4 right-4 px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
210
  >
211
  Analyze
212
- </button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  </div>
214
 
215
  {isProcessing && (
@@ -293,7 +329,22 @@ const Dashboard = () => {
293
  </div>
294
  </div>
295
  )}
296
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  </div>
299
  </div>
 
110
  });
111
 
112
  const Dashboard = () => {
113
+ const [knowledgeGraphData, setKnowledgeGraphData] = useState(null);
114
  const [inputText, setInputText] = useState('');
115
  const [isProcessing, setIsProcessing] = useState(false);
116
  const [result, setResult] = useState<{
 
185
  }
186
  };
187
 
188
// Request a rendered knowledge graph for the current input text and store
// the response payload in state for display.
const generateKnowledgeGraph = async () => {
  try {
    const response = await fetch('http://localhost:8000/generate-knowledge-graph', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ text: inputText }),
    });

    // A non-2xx reply carries an error payload, not graph markup; do not
    // store it as graph data.
    if (!response.ok) {
      throw new Error(`Knowledge graph request failed with status ${response.status}`);
    }

    const data = await response.json();
    setKnowledgeGraphData(data);
  } catch (err) {
    console.error('Knowledge graph generation error:', err);
  }
};
204
+
205
+
206
  return (
207
  <div className="min-h-screen relative overflow-hidden">
208
  <AnimatedBackground />
 
222
  className="w-full h-40 p-4 rounded-xl bg-gray-700 text-white placeholder-gray-400 focus:outline-none"
223
  placeholder="Enter text to analyze..."
224
  />
225
+ {/* <button
226
  onClick={processInput}
227
  disabled={isProcessing || !inputText.trim()}
228
  className="absolute bottom-4 right-4 px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
229
  >
230
  Analyze
231
+ </button> */}
232
+ <div className="flex gap-2">
233
+ <button
234
+ onClick={processInput}
235
+ disabled={isProcessing || !inputText.trim()}
236
+ className="px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
237
+ >
238
+ Analyze
239
+ </button>
240
+ <button
241
+ onClick={generateKnowledgeGraph}
242
+ disabled={!inputText.trim()}
243
+ className="px-6 py-2 rounded-lg bg-green-500 hover:bg-green-600 text-white disabled:opacity-50"
244
+ >
245
+ Visualize Graph
246
+ </button>
247
+ </div>
248
+
249
  </div>
250
 
251
  {isProcessing && (
 
329
  </div>
330
  </div>
331
  )}
332
+ {knowledgeGraphData && (
333
+ <div className="mt-4 bg-gray-800 rounded-lg p-4">
334
+ <h3 className="text-xl font-semibold text-white mb-3">Knowledge Graph</h3>
335
+ <div className="w-full h-[500px] bg-gray-900 rounded-lg overflow-hidden">
336
+ <iframe
337
+ srcDoc={knowledgeGraphData}
338
+ className="w-full h-full border-0"
339
+ style={{
340
+ backgroundColor: 'transparent',
341
+ width: '100%',
342
+ height: '100%'
343
+ }}
344
+ />
345
+ </div>
346
+ </div>
347
+ )}
348
 
349
  </div>
350
  </div>