fake-detection / main.py
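"""FastAPI service for real-time fake-audio detection.

Serves a small web UI, accepts raw float32 audio chunks over a WebSocket,
extracts spectral features with librosa, and classifies each chunk with a
pre-trained model (loaded from models/xgb_test.pkl), broadcasting
'fake'/'real' back to connected clients.
"""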
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
import sounddevice as sd
import numpy as np
import librosa
import joblib
import uvicorn
import threading
import asyncio
import logging
from typing import List
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
@app.get("/", response_class=HTMLResponse)
async def get(request: Request):
logger.info("Saving the index page")
with open("templates/index.html") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
@app.get("/health")
def health_check():
return {"status": "ok"}
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
is_detecting = False
detection_thread = None
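# Pre-trained classifier (joblib pickle); it is fed the feature layout
# produced by extract_features() below.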
model = joblib.load('models/xgb_test.pkl')
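# Tracks open WebSocket connections and broadcasts each detection result
# to every connected client.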
class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def send_message(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)
manager = ConnectionManager()
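# Time-averaged spectral features: MFCCs, chroma, spectral contrast and
# spectral centroid. With librosa defaults this is a 13 + 12 + 7 + 1 = 33-dim vector.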
def extract_features(audio):
    sr = 16000
    mfccs = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
    mfccs = np.mean(mfccs, axis=1)
    chroma = librosa.feature.chroma_stft(y=audio, sr=sr)
    chroma = np.mean(chroma, axis=1)
    contrast = librosa.feature.spectral_contrast(y=audio, sr=sr)
    contrast = np.mean(contrast, axis=1)
    centroid = librosa.feature.spectral_centroid(y=audio, sr=sr)
    centroid = np.mean(centroid, axis=1)
    combined_features = np.hstack([mfccs, chroma, contrast, centroid])
    return combined_features
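# Decode a raw float32 buffer, classify it, and broadcast 'fake'/'real'.
# Note: model.predict() runs synchronously inside the event loop.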
async def process_audio_data(audio_data):
    audio_np = np.frombuffer(audio_data, dtype=np.float32)
    features = extract_features(audio_np)
    features = features.reshape(1, -1)
    prediction = model.predict(features)
    is_fake = prediction[0]
    result = 'fake' if is_fake else 'real'
    await manager.send_message(result)
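# Toggle the global detection flag that gates classification in the
# WebSocket handler below.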
@app.post("/start_detection")
async def start_detection():
global is_detecting
if not is_detecting:
is_detecting = True
return JSONResponse(content={'status': 'detection_started'})
@app.post("/stop_detection")
async def stop_detection():
global is_detecting
is_detecting = False
return JSONResponse(content={'status': 'detection_stopped'})
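# Clients stream raw float32 PCM chunks (16 kHz mono, matching extract_features)
# and receive one 'fake'/'real' message per chunk.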
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_bytes()
await process_audio_data(data)
except WebSocketDisconnect:
manager.disconnect(websocket)
if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=7860)
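
# Example client (sketch, not part of this app): stream one-second float32
# microphone chunks to the /ws endpoint and print each 'fake'/'real' reply.
# Assumes the `websockets` package and a local server on port 7860; the
# sounddevice module imported above would capture the audio.
#
#   import asyncio
#   import sounddevice as sd
#   import websockets
#
#   async def stream_mic(url="ws://localhost:7860/ws", sr=16000):
#       async with websockets.connect(url) as ws:
#           while True:
#               chunk = sd.rec(sr, samplerate=sr, channels=1, dtype="float32")
#               sd.wait()  # block until the one-second recording finishes
#               await ws.send(chunk.ravel().tobytes())
#               print(await ws.recv())  # 'fake' or 'real'
#
#   asyncio.run(stream_mic())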