Krish Patel
committed on
Commit
·
4adafc2
1
Parent(s):
8c24dde
Separated knowledge graph and model responses
Browse files- app.py +62 -0
- final.py +7 -91
- nexus-frontend/src/pages/Dashboard.tsx +54 -3
app.py
CHANGED
@@ -4,6 +4,9 @@ from pydantic import BaseModel
|
|
4 |
from final import predict_news, get_gemini_analysis
|
5 |
import os
|
6 |
from tempfile import NamedTemporaryFile
|
|
|
|
|
|
|
7 |
|
8 |
app = FastAPI()
|
9 |
|
@@ -81,3 +84,62 @@ async def detect_deepfake(file: UploadFile = File(...)):
|
|
81 |
|
82 |
except Exception as e:
|
83 |
return {"error": str(e)}, 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from final import predict_news, get_gemini_analysis
|
5 |
import os
|
6 |
from tempfile import NamedTemporaryFile
|
7 |
+
from knowledge_graph_generator import KnowledgeGraphBuilder
|
8 |
+
import networkx as nx
|
9 |
+
import plotly.graph_objects as go
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
|
|
84 |
|
85 |
except Exception as e:
|
86 |
return {"error": str(e)}, 500
|
87 |
+
|
88 |
+
@app.post("/generate-knowledge-graph")
async def generate_knowledge_graph(news: NewsInput):
    """Classify the submitted text and render its knowledge graph as HTML.

    A new ``KnowledgeGraphBuilder`` is created for every request, updated
    with ``news.text`` and the classifier verdict, positioned with a
    NetworkX spring layout and drawn as a Plotly figure.  Edges are drawn
    red when the text is classified as ``"FAKE"`` and green otherwise.

    Args:
        news: Request body; only ``news.text`` is read.

    Returns:
        The figure serialised by ``Figure.to_html()`` (a ``str``).
    """
    kg_builder = KnowledgeGraphBuilder()
    is_fake = predict_news(news.text) == "FAKE"
    kg_builder.update_knowledge_graph(news.text, not is_fake)

    graph = kg_builder.knowledge_graph
    pos = nx.spring_layout(graph)

    # Gather coordinates in plain lists and hand them to Plotly once.
    # Growing ``trace['x'] += (...)`` per element copies and reassigns the
    # whole array on every step, which is quadratic in the graph size.
    edge_x, edge_y = [], []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # ``None`` breaks the line so each edge is drawn as its own segment.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    node_x, node_y, node_labels = [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_labels.append(node)

    # Single edge trace; its colour encodes the classifier verdict.
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(
            width=2,
            # rgba is used so the edges are slightly transparent.
            color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)',
        ),
        hoverinfo='none',
        mode='lines',
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black'),
        ),
        text=node_labels,
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode='closest',
            margin=dict(b=0, l=0, r=0, t=0),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Transparent backgrounds let the host page show through.
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
        ),
    )

    return fig.to_html()
|
final.py
CHANGED
@@ -1,9 +1,6 @@
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
-
import networkx as nx
|
4 |
import spacy
|
5 |
-
import pickle
|
6 |
-
import pandas as pd
|
7 |
import google.generativeai as genai
|
8 |
import json
|
9 |
import os
|
@@ -20,23 +17,10 @@ tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
|
|
20 |
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
21 |
model.eval()
|
22 |
|
23 |
-
#########################
|
24 |
def setup_gemini():
    """Configure the Gemini client and return a ``gemini-pro`` model handle.

    The API key is read from the ``GEMINI_API`` environment variable.
    """
    api_key = os.getenv("GEMINI_API")
    genai.configure(api_key=api_key)
    return genai.GenerativeModel('gemini-pro')
|
28 |
-
#########################
|
29 |
-
|
30 |
-
# Load the knowledge graph
|
31 |
-
graph_path = "./models/knowledge_graph.pkl" # Replace with the actual path to your knowledge graph
|
32 |
-
with open(graph_path, 'rb') as f:
|
33 |
-
graph_data = pickle.load(f)
|
34 |
-
|
35 |
-
knowledge_graph = nx.DiGraph()
|
36 |
-
knowledge_graph.add_nodes_from(graph_data['nodes'].items())
|
37 |
-
for u, edges in graph_data['edges'].items():
|
38 |
-
for v, data in edges.items():
|
39 |
-
knowledge_graph.add_edge(u, v, **data)
|
40 |
|
41 |
def predict_with_model(text):
|
42 |
"""Predict whether the news is real or fake using the ML model."""
|
@@ -47,76 +31,17 @@ def predict_with_model(text):
|
|
47 |
predicted_label = torch.argmax(probabilities, dim=-1).item()
|
48 |
return "FAKE" if predicted_label == 1 else "REAL"
|
49 |
|
50 |
-
def update_knowledge_graph(text, is_real):
    """Fold the entities found in *text* into the module-level graph.

    Each entity becomes (or updates) a node carrying ``real_count`` and
    ``fake_count`` tallies; every pair of entities from the same text is
    linked by an edge whose ``weight`` counts how often the pair was seen.

    Args:
        text: Article text passed to ``extract_entities``.
        is_real: Truthy when the article is considered real.
    """
    entities = extract_entities(text)
    tally_key = 'real_count' if is_real else 'fake_count'

    # Nodes: bump the matching tally, creating the node on first sight.
    for name, label in entities:
        if knowledge_graph.has_node(name):
            knowledge_graph.nodes[name][tally_key] += 1
            continue
        knowledge_graph.add_node(
            name,
            type=label,
            real_count=1 if is_real else 0,
            fake_count=0 if is_real else 1,
        )

    # Edges: link every entity with each entity that follows it in the text.
    names = [name for name, _ in entities]
    for idx, first in enumerate(names):
        for second in names[idx + 1:]:
            if knowledge_graph.has_edge(first, second):
                knowledge_graph[first][second]['weight'] += 1
            else:
                knowledge_graph.add_edge(
                    first,
                    second,
                    weight=1,
                    is_real=is_real,
                )
|
78 |
-
|
79 |
def extract_entities(text):
    """Return ``(entity_text, entity_label)`` pairs found by spaCy in *text*."""
    return [(span.text, span.label_) for span in nlp(text).ents]
|
84 |
|
85 |
-
def predict_with_knowledge_graph(text):
    """Vote ``"REAL"`` or ``"FAKE"`` from the tallies of known entities.

    For every entity of *text* that already exists in the module-level
    graph, the real and fake shares of its tallies are added to running
    scores.  ``"REAL"`` is returned only when the real score is strictly
    greater; ties (including texts with no known entities) yield ``"FAKE"``.
    """
    real_score = 0
    fake_score = 0

    for name, _ in extract_entities(text):
        if not knowledge_graph.has_node(name):
            continue
        attrs = knowledge_graph.nodes[name]
        reals = attrs.get('real_count', 0)
        fakes = attrs.get('fake_count', 0)
        seen = reals + fakes
        if seen > 0:
            real_score += reals / seen
            fake_score += fakes / seen

    return "REAL" if real_score > fake_score else "FAKE"
|
104 |
-
|
105 |
def predict_news(text):
|
106 |
-
"""Predict whether the news is real or fake using
|
107 |
# Predict with the ML model
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
# Update the knowledge graph
|
112 |
-
update_knowledge_graph(text, is_real)
|
113 |
-
|
114 |
-
# Predict with the knowledge graph
|
115 |
-
kg_prediction = predict_with_knowledge_graph(text)
|
116 |
-
|
117 |
-
# Combine predictions (for simplicity, we use the ML model's prediction here)
|
118 |
-
# You can enhance this by combining the scores from both predictions
|
119 |
-
return ml_prediction if ml_prediction == kg_prediction else "UNCERTAIN"
|
120 |
|
121 |
def analyze_content_gemini(model, text):
|
122 |
prompt = f"""Analyze this news text and return a JSON object with the following structure:
|
@@ -164,16 +89,12 @@ def analyze_content_gemini(model, text):
|
|
164 |
Analyze this text and return only the JSON response: {text}"""
|
165 |
|
166 |
response = model.generate_content(prompt)
|
167 |
-
# return json.loads(response.text)
|
168 |
-
# Add error handling and response cleaning
|
169 |
try:
|
170 |
-
# Clean the response text to ensure it's valid JSON
|
171 |
cleaned_text = response.text.strip()
|
172 |
if cleaned_text.startswith('```json'):
|
173 |
-
cleaned_text = cleaned_text[7:-3]
|
174 |
return json.loads(cleaned_text)
|
175 |
except json.JSONDecodeError:
|
176 |
-
# Return a default structured response if JSON parsing fails
|
177 |
return {
|
178 |
"gemini_analysis": {
|
179 |
"predicted_classification": "UNCERTAIN",
|
@@ -182,7 +103,6 @@ def analyze_content_gemini(model, text):
|
|
182 |
}
|
183 |
}
|
184 |
|
185 |
-
|
186 |
def clean_gemini_output(text):
|
187 |
"""Remove markdown formatting from Gemini output"""
|
188 |
text = text.replace('##', '')
|
@@ -193,10 +113,7 @@ def get_gemini_analysis(text):
|
|
193 |
"""Get detailed content analysis from Gemini."""
|
194 |
gemini_model = setup_gemini()
|
195 |
gemini_analysis = analyze_content_gemini(gemini_model, text)
|
196 |
-
# cleaned_analysis = clean_gemini_output(gemini_analysis)
|
197 |
-
# return cleaned_analysis
|
198 |
return gemini_analysis
|
199 |
-
#########################
|
200 |
|
201 |
def main():
|
202 |
print("Welcome to the News Classifier!")
|
@@ -209,15 +126,14 @@ def main():
|
|
209 |
print("Thank you for using the News Classifier!")
|
210 |
return
|
211 |
|
212 |
-
#
|
213 |
prediction = predict_news(news_text)
|
214 |
-
print(f"\nML
|
215 |
|
216 |
-
#
|
217 |
print("\n=== Detailed Gemini Analysis ===")
|
218 |
gemini_result = get_gemini_analysis(news_text)
|
219 |
print(gemini_result)
|
220 |
|
221 |
-
|
222 |
if __name__ == "__main__":
|
223 |
main()
|
|
|
1 |
import torch
|
2 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
3 |
import spacy
|
|
|
|
|
4 |
import google.generativeai as genai
|
5 |
import json
|
6 |
import os
|
|
|
17 |
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
18 |
model.eval()
|
19 |
|
|
|
20 |
def setup_gemini():
    """Configure the Gemini client and return a ``gemini-pro`` model handle.

    The API key is read from the ``GEMINI_API`` environment variable.
    """
    api_key = os.getenv("GEMINI_API")
    genai.configure(api_key=api_key)
    return genai.GenerativeModel('gemini-pro')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def predict_with_model(text):
|
26 |
"""Predict whether the news is real or fake using the ML model."""
|
|
|
31 |
predicted_label = torch.argmax(probabilities, dim=-1).item()
|
32 |
return "FAKE" if predicted_label == 1 else "REAL"
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def extract_entities(text):
    """Return ``(entity_text, entity_label)`` pairs found by spaCy in *text*."""
    return [(span.text, span.label_) for span in nlp(text).ents]
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def predict_news(text):
    """Classify *text* as real or fake news using only the ML model.

    Thin wrapper that delegates to ``predict_with_model`` and returns its
    label unchanged.
    """
    return predict_with_model(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def analyze_content_gemini(model, text):
|
47 |
prompt = f"""Analyze this news text and return a JSON object with the following structure:
|
|
|
89 |
Analyze this text and return only the JSON response: {text}"""
|
90 |
|
91 |
response = model.generate_content(prompt)
|
|
|
|
|
92 |
try:
|
|
|
93 |
cleaned_text = response.text.strip()
|
94 |
if cleaned_text.startswith('```json'):
|
95 |
+
cleaned_text = cleaned_text[7:-3]
|
96 |
return json.loads(cleaned_text)
|
97 |
except json.JSONDecodeError:
|
|
|
98 |
return {
|
99 |
"gemini_analysis": {
|
100 |
"predicted_classification": "UNCERTAIN",
|
|
|
103 |
}
|
104 |
}
|
105 |
|
|
|
106 |
def clean_gemini_output(text):
|
107 |
"""Remove markdown formatting from Gemini output"""
|
108 |
text = text.replace('##', '')
|
|
|
113 |
"""Get detailed content analysis from Gemini."""
|
114 |
gemini_model = setup_gemini()
|
115 |
gemini_analysis = analyze_content_gemini(gemini_model, text)
|
|
|
|
|
116 |
return gemini_analysis
|
|
|
117 |
|
118 |
def main():
|
119 |
print("Welcome to the News Classifier!")
|
|
|
126 |
print("Thank you for using the News Classifier!")
|
127 |
return
|
128 |
|
129 |
+
# Get ML prediction
|
130 |
prediction = predict_news(news_text)
|
131 |
+
print(f"\nML Analysis: {prediction}")
|
132 |
|
133 |
+
# Get Gemini analysis
|
134 |
print("\n=== Detailed Gemini Analysis ===")
|
135 |
gemini_result = get_gemini_analysis(news_text)
|
136 |
print(gemini_result)
|
137 |
|
|
|
138 |
if __name__ == "__main__":
|
139 |
main()
|
nexus-frontend/src/pages/Dashboard.tsx
CHANGED
@@ -110,6 +110,7 @@ const AnimatedBackground = memo(() => {
|
|
110 |
});
|
111 |
|
112 |
const Dashboard = () => {
|
|
|
113 |
const [inputText, setInputText] = useState('');
|
114 |
const [isProcessing, setIsProcessing] = useState(false);
|
115 |
const [result, setResult] = useState<{
|
@@ -184,6 +185,24 @@ const Dashboard = () => {
|
|
184 |
}
|
185 |
};
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
return (
|
188 |
<div className="min-h-screen relative overflow-hidden">
|
189 |
<AnimatedBackground />
|
@@ -203,13 +222,30 @@ const Dashboard = () => {
|
|
203 |
className="w-full h-40 p-4 rounded-xl bg-gray-700 text-white placeholder-gray-400 focus:outline-none"
|
204 |
placeholder="Enter text to analyze..."
|
205 |
/>
|
206 |
-
<button
|
207 |
onClick={processInput}
|
208 |
disabled={isProcessing || !inputText.trim()}
|
209 |
className="absolute bottom-4 right-4 px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
|
210 |
>
|
211 |
Analyze
|
212 |
-
</button>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
</div>
|
214 |
|
215 |
{isProcessing && (
|
@@ -293,7 +329,22 @@ const Dashboard = () => {
|
|
293 |
</div>
|
294 |
</div>
|
295 |
)}
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
297 |
|
298 |
</div>
|
299 |
</div>
|
|
|
110 |
});
|
111 |
|
112 |
const Dashboard = () => {
|
113 |
+
const [knowledgeGraphData, setKnowledgeGraphData] = useState(null);
|
114 |
const [inputText, setInputText] = useState('');
|
115 |
const [isProcessing, setIsProcessing] = useState(false);
|
116 |
const [result, setResult] = useState<{
|
|
|
185 |
}
|
186 |
};
|
187 |
|
188 |
+
// Ask the backend for the knowledge-graph visualisation of the current input
// text and store the decoded response in `knowledgeGraphData`.
const generateKnowledgeGraph = async () => {
  try {
    const response = await fetch('http://localhost:8000/generate-knowledge-graph', {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ text: inputText }),
    });

    // fetch() only rejects on network failure. Without this guard an HTTP
    // error body (e.g. a validation error object) would be stored as the
    // graph payload.
    if (!response.ok) {
      throw new Error(`Knowledge graph request failed with status ${response.status}`);
    }

    const data = await response.json();
    setKnowledgeGraphData(data);
  } catch (err) {
    console.error('Knowledge graph generation error:', err);
  }
};
|
204 |
+
|
205 |
+
|
206 |
return (
|
207 |
<div className="min-h-screen relative overflow-hidden">
|
208 |
<AnimatedBackground />
|
|
|
222 |
className="w-full h-40 p-4 rounded-xl bg-gray-700 text-white placeholder-gray-400 focus:outline-none"
|
223 |
placeholder="Enter text to analyze..."
|
224 |
/>
|
225 |
+
{/* <button
|
226 |
onClick={processInput}
|
227 |
disabled={isProcessing || !inputText.trim()}
|
228 |
className="absolute bottom-4 right-4 px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
|
229 |
>
|
230 |
Analyze
|
231 |
+
</button> */}
|
232 |
+
<div className="flex gap-2">
|
233 |
+
<button
|
234 |
+
onClick={processInput}
|
235 |
+
disabled={isProcessing || !inputText.trim()}
|
236 |
+
className="px-6 py-2 rounded-lg bg-blue-500 hover:bg-blue-600 text-white disabled:opacity-50"
|
237 |
+
>
|
238 |
+
Analyze
|
239 |
+
</button>
|
240 |
+
<button
|
241 |
+
onClick={generateKnowledgeGraph}
|
242 |
+
disabled={!inputText.trim()}
|
243 |
+
className="px-6 py-2 rounded-lg bg-green-500 hover:bg-green-600 text-white disabled:opacity-50"
|
244 |
+
>
|
245 |
+
Visualize Graph
|
246 |
+
</button>
|
247 |
+
</div>
|
248 |
+
|
249 |
</div>
|
250 |
|
251 |
{isProcessing && (
|
|
|
329 |
</div>
|
330 |
</div>
|
331 |
)}
|
332 |
+
{knowledgeGraphData && (
|
333 |
+
<div className="mt-4 bg-gray-800 rounded-lg p-4">
|
334 |
+
<h3 className="text-xl font-semibold text-white mb-3">Knowledge Graph</h3>
|
335 |
+
<div className="w-full h-[500px] bg-gray-900 rounded-lg overflow-hidden">
|
336 |
+
<iframe
|
337 |
+
srcDoc={knowledgeGraphData}
|
338 |
+
className="w-full h-full border-0"
|
339 |
+
style={{
|
340 |
+
backgroundColor: 'transparent',
|
341 |
+
width: '100%',
|
342 |
+
height: '100%'
|
343 |
+
}}
|
344 |
+
/>
|
345 |
+
</div>
|
346 |
+
</div>
|
347 |
+
)}
|
348 |
|
349 |
</div>
|
350 |
</div>
|