Spaces:
Sleeping
Sleeping
"""Server that will listen for GET and POST requests from the client.""" | |
import time | |
import logging | |
from pathlib import Path | |
from typing import List | |
from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
from fastapi.responses import JSONResponse, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from concrete.ml.deployment import FHEModelServer | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow cross-origin calls from any site, for any method and header.
# Credentials are deliberately NOT enabled:
# - The endpoints below take only form fields and uploaded files (no cookies or auth).
# - A wildcard origin combined with allow_credentials=True is disallowed by the CORS rules.
#   With that combination, Starlette would echo every caller's Origin back as allowed
#   together with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=False,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Initialize the FHE server from the "deployment_files" directory next to this module
# (presumably the compiled model artifacts produced by Concrete ML — confirm).
DEPLOYMENT_DIR = Path(__file__).parent / "deployment_files"
FHE_SERVER = FHEModelServer(DEPLOYMENT_DIR)
def get_server_file_path(file_type: str, user_id: str) -> Path:
    """Get the path to a file on the server.

    Files live in the ``server_tmp`` directory next to this module and are
    named ``{file_type}_{user_id}``.

    Args:
        file_type: Kind of file, e.g. "encrypted_image", "evaluation_key" or
            "encrypted_output".
        user_id: Identifier supplied by the client.

    Returns:
        The path of the file directly inside ``server_tmp``.

    Raises:
        ValueError: If the combined name would not be a direct child of
            ``server_tmp``. This happens when it contains a path separator,
            a ``..`` component, or is an absolute path.
    """
    base_dir = Path(__file__).parent / "server_tmp"
    file_path = base_dir / f"{file_type}_{user_id}"
    # Both parts can come straight from the client (form field / upload file
    # name), so refuse anything that would read or write outside server_tmp.
    if file_path.parent != base_dir:
        raise ValueError(f"Invalid file name components: {file_type!r}, {user_id!r}")
    return file_path
# Input file types that run_fhe reads back; every uploaded file name must start
# with one of these.
_INPUT_FILE_TYPES = ("encrypted_image", "evaluation_key")


def _input_file_type(filename) -> str:
    """Return the expected input file type that ``filename`` starts with.

    The type is matched as a prefix rather than taken from something like
    ``filename.split("_")[0]``. The expected types themselves contain
    underscores, so splitting would yield "encrypted" / "evaluation" and the
    stored files would never match what run_fhe looks for.

    Raises:
        HTTPException: 400 if ``filename`` is empty/None or matches no
            expected type.
    """
    for file_type in _INPUT_FILE_TYPES:
        if filename and filename.startswith(file_type):
            return file_type
    raise HTTPException(
        status_code=400,
        detail=(
            f"Unexpected upload file name {filename!r}; expected a name starting "
            f"with one of {list(_INPUT_FILE_TYPES)}"
        ),
    )


# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function; confirm it is registered with `app`.
async def send_input(user_id: str = Form(), files: List[UploadFile] = File(...)):
    """Receive the encrypted input image and the evaluation key from the client.

    Each uploaded file is stored as ``server_tmp/{file_type}_{user_id}``. Here
    ``file_type`` is the expected input type that the upload's file name
    starts with.

    Args:
        user_id: Client identifier used to namespace the stored files.
        files: Uploaded files. Each file name must start with
            "encrypted_image" or "evaluation_key".

    Returns:
        JSONResponse with a confirmation message.

    Raises:
        HTTPException: 400 for an unexpected file name; 500 for any other
            failure.
    """
    try:
        for file in files:
            file_path = get_server_file_path(_input_file_type(file.filename), user_id)
            # Nothing visible in this module creates server_tmp, so create it on demand.
            file_path.parent.mkdir(parents=True, exist_ok=True)
            # Read before opening for write, so a failed read leaves no truncated file.
            content = await file.read()
            with file_path.open("wb") as buffer:
                buffer.write(content)
        return JSONResponse(content={"message": "Files received successfully"})
    except HTTPException:
        # Client errors keep their own status code instead of becoming a 500.
        raise
    except Exception as e:
        logger.exception("Error in send_input: %s", e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function; confirm it is registered with `app`.
def run_fhe(user_id: str = Form()):
    """Execute seizure detection on the encrypted input image using FHE.

    Reads the encrypted input and evaluation key previously stored for
    ``user_id`` and passes them to ``FHE_SERVER.run``. The encrypted result is
    written to the ``encrypted_output`` file for the same user.

    Args:
        user_id: Identifier the client used when uploading its files.

    Returns:
        JSONResponse whose body is the FHE execution time in seconds, rounded
        to two decimals.

    Raises:
        HTTPException: 500 if the input files are missing or any step fails.
    """
    logger.info("Starting FHE execution for user %s", user_id)
    try:
        # Retrieve the encrypted input image and the evaluation key paths
        encrypted_image_path = get_server_file_path("encrypted_image", user_id)
        evaluation_key_path = get_server_file_path("evaluation_key", user_id)

        # Read directly (EAFP) instead of exists()-then-open, which can race
        # with a concurrent write or cleanup; keep the same client-facing message.
        try:
            encrypted_image = encrypted_image_path.read_bytes()
            evaluation_key = evaluation_key_path.read_bytes()
        except FileNotFoundError as err:
            raise FileNotFoundError("Required files not found") from err

        # perf_counter is monotonic, so the measured duration is not skewed by
        # wall-clock adjustments during a long FHE run.
        start = time.perf_counter()
        encrypted_output = FHE_SERVER.run(encrypted_image, evaluation_key)
        fhe_execution_time = round(time.perf_counter() - start, 2)

        # Store the result so a later get_output call can return it.
        get_server_file_path("encrypted_output", user_id).write_bytes(encrypted_output)

        logger.info(
            "FHE execution completed for user %s in %s seconds", user_id, fhe_execution_time
        )
        return JSONResponse(content=fhe_execution_time)
    except Exception as e:
        logger.exception("Error in run_fhe for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# function; confirm it is registered with `app`.
def get_output(user_id: str = Form()):
    """Retrieve the encrypted output.

    Args:
        user_id: Identifier the client used for its earlier uploads and FHE run.

    Returns:
        Response containing the raw bytes of the stored encrypted output,
        served as ``application/octet-stream``.

    Raises:
        HTTPException: 500 if the output file does not exist or reading it
            fails.
    """
    try:
        encrypted_output_path = get_server_file_path("encrypted_output", user_id)
        # Read directly (EAFP) rather than exists()-then-open; keep the
        # original client-facing message for a missing file.
        try:
            encrypted_output = encrypted_output_path.read_bytes()
        except FileNotFoundError as err:
            raise FileNotFoundError("Encrypted output file not found") from err
        # The payload is opaque ciphertext, so label it as generic binary data.
        return Response(encrypted_output, media_type="application/octet-stream")
    except Exception as e:
        logger.exception("Error in get_output for user %s: %s", user_id, e)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: serve the ASGI app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)