File size: 4,579 Bytes
b6597a0 4adafc2 b6597a0 990f77e b6597a0 990f77e 8c24dde b6597a0 8c24dde b6597a0 8c24dde b6597a0 8c24dde b6597a0 8c24dde b6597a0 4adafc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
from tempfile import NamedTemporaryFile

import networkx as nx
import plotly.graph_objects as go
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from final import predict_news, get_gemini_analysis
from knowledge_graph_generator import KnowledgeGraphBuilder
app = FastAPI()

# Cross-origin settings for the browser front end (the React app served from
# http://localhost:5173). That origin is granted access with credentials and
# with any HTTP method or request header.
_cors_options = {
    "allow_origins": ["http://localhost:5173"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Request models and API endpoints
# Request body shared by the /analyze and /generate-knowledge-graph endpoints.
# (Documented with comments rather than a docstring: pydantic would copy a
# class docstring into the generated OpenAPI schema description.)
class NewsInput(BaseModel):
    # The raw news text to analyse.
    text: str
@app.post("/analyze")
async def analyze_news(news: NewsInput):
    """Classify ``news.text`` and attach a Gemini-generated analysis.

    Args:
        news: Request body carrying the text to analyse.

    Returns:
        ``{"prediction": <predict_news output>,
        "detailed_analysis": <get_gemini_analysis output>}``.
    """
    # predict_news and get_gemini_analysis are plain (non-async) functions;
    # calling them directly inside ``async def`` would stall the event loop and
    # every other in-flight request until they finish. Run them in FastAPI's
    # worker thread pool instead.
    # NOTE(review): assumes predict_news is safe to call from a worker thread
    # (true of most inference code, but confirm for the underlying model).
    prediction = await run_in_threadpool(predict_news, news.text)
    gemini_analysis = await run_in_threadpool(get_gemini_analysis, news.text)
    return {
        "prediction": prediction,
        "detailed_analysis": gemini_analysis,
    }
# @app.post("/detect-deepfake")
# async def detect_deepfake(image: UploadFile = File(...)):
# try:
# # Save uploaded image temporarily
# with NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
# contents = await image.read()
# temp_file.write(contents)
# temp_file_path = temp_file.name
# # Use your existing deepfake detection function
# from deepfake2.testing2 import predict_image # Use your existing function
# result = predict_image(temp_file_path)
# # Clean up temp file
# os.remove(temp_file_path)
# return result
# except Exception as e:
# return {"error": str(e)}, 500
@app.post("/detect-deepfake")
async def detect_deepfake(file: UploadFile = File(...)):
    """Run deepfake detection on an uploaded image or ``.mp4`` video.

    The upload is written to a temporary file that keeps the original
    extension, because the detectors in ``deepfake2.testing2`` are called with
    a file path. A filename ending in ``.mp4`` (case-insensitive) is sent to
    ``predict_video``; anything else is sent to ``predict_image``.

    Args:
        file: The uploaded multipart file.

    Returns:
        ``{"result": <detector output>, "file_type": "video" | "image"}`` on
        success; otherwise a JSON ``{"error": <message>}`` body with HTTP
        status 500.
    """
    # UploadFile.filename may be None; fall back to "" so the extension
    # handling below cannot raise on it.
    filename = file.filename or ""
    temp_file_path = None
    try:
        suffix = os.path.splitext(filename)[1]
        # delete=False so the file survives closing and the detector can
        # reopen it by path; it is removed in the ``finally`` clause below.
        with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Imported inside the handler (as before), so the module is only
        # loaded when this endpoint is first called; an import failure is
        # reported to the client as a 500 like any other error.
        from deepfake2.testing2 import predict_image, predict_video

        if filename.lower().endswith('.mp4'):
            predict, file_type = predict_video, "video"
        else:
            predict, file_type = predict_image, "image"

        # The detectors are synchronous; run them off the event loop so other
        # requests are not blocked while inference runs.
        result = await run_in_threadpool(predict, temp_file_path)
        return {
            "result": result,
            "file_type": file_type,
        }
    except Exception as e:
        # A ``(body, 500)`` tuple is a Flask idiom: FastAPI would serialise it
        # as a JSON list with status 200. Set the status code explicitly.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Always remove the temp file, including when the detector raises
        # (the previous version leaked it on any error). Cleanup is
        # best-effort: a failure here must not mask the real response.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
def _graph_to_figure(graph, is_fake):
    """Render ``graph`` as a Plotly figure using a spring layout.

    Edges are drawn as a single line trace, semi-transparent red when
    ``is_fake`` is true and green otherwise. Nodes are white markers labelled
    with the node itself.
    """
    pos = nx.spring_layout(graph)

    # Collect coordinates in plain lists and build each trace once. Growing a
    # trace property with ``trace['x'] += (...)`` copies all accumulated
    # coordinates on every edge/node, which is quadratic in graph size.
    edge_x, edge_y = [], []
    for source, target in graph.edges():
        x0, y0 = pos[source]
        x1, y1 = pos[target]
        # The trailing None breaks the polyline so each edge is its own segment.
        edge_x += (x0, x1, None)
        edge_y += (y0, y1, None)

    node_x, node_y, node_text = [], [], []
    for node in graph.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_text.append(node)

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(
            width=2,
            # rgba so the edges are slightly transparent.
            color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
        ),
        hoverinfo='none',
        mode='lines'
    )
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        textposition='top center',
        marker=dict(
            size=15,
            color='white',
            line=dict(width=2, color='black')
        ),
        text=node_text
    )

    return go.Figure(data=[edge_trace, node_trace],
                     layout=go.Layout(
                         showlegend=False,
                         hovermode='closest',
                         margin=dict(b=0, l=0, r=0, t=0),
                         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                         plot_bgcolor='rgba(0,0,0,0)',
                         paper_bgcolor='rgba(0,0,0,0)'
                     ))


def _build_knowledge_graph_html(text):
    """Classify ``text``, build its knowledge graph, and return it as HTML."""
    kg_builder = KnowledgeGraphBuilder()
    # Treated as fake exactly when predict_news returns the string "FAKE".
    is_fake = predict_news(text) == "FAKE"
    # NOTE(review): the second argument is passed as ``not is_fake``,
    # presumably an "is real" flag; confirm against KnowledgeGraphBuilder.
    kg_builder.update_knowledge_graph(text, not is_fake)
    return _graph_to_figure(kg_builder.knowledge_graph, is_fake).to_html()


@app.post("/generate-knowledge-graph")
async def generate_knowledge_graph(news: NewsInput):
    """Return a Plotly HTML rendering of the knowledge graph for ``news.text``.

    Args:
        news: Request body carrying the text to graph.

    Returns:
        The figure's standalone HTML as a string. FastAPI sends it JSON-encoded
        (not as an ``HTMLResponse``); that is unchanged because the front end
        presumably decodes it that way — verify before changing.
    """
    # Prediction, graph construction and the spring layout are all synchronous
    # work; run them in the thread pool so the event loop stays responsive.
    return await run_in_threadpool(_build_knowledge_graph_html, news.text)
|