"Custom client, server and development interfaces for running seizure detection models in FHE."

from pathlib import Path

from common import SEIZURE_DETECTION_MODEL_PATH
from concrete import fhe

from seizure_detection import SeizureDetector


class FHEServer:
    """Server interface to run a FHE circuit for seizure detection."""

    def __init__(self, model_path):
        """Load the server-side FHE circuit from disk.

        Args:
            model_path (Union[str, Path]): Directory containing the ``server.zip`` artifact
                (as written by ``FHEDev.save``).
        """
        # Coerce to Path so that plain string paths also work with the ``/`` operator below.
        self.model_path = Path(model_path)

        # Load the FHE circuit
        self.server = fhe.Server.load(self.model_path / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run seizure detection on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized encrypted output of the circuit. The client decrypts and
                post-processes it into the final prediction.
        """
        # Deserialize the encrypted input image and the evaluation keys
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the seizure detection in FHE; the server never sees clear data
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize the encrypted output so it can be sent back to the client
        return encrypted_output.serialize()


class FHEDev:
    """Development interface to export the compiled seizure detection circuit."""

    def __init__(self, seizure_detector, model_path):
        """Prepare the export directory for a compiled seizure detector.

        Args:
            seizure_detector (SeizureDetector): Model whose ``fhe_circuit`` attribute is exported.
            model_path (Union[str, Path]): Directory where the circuit artifacts are written.
                It is created (including parents) if it does not exist yet.
        """
        self.seizure_detector = seizure_detector

        # Coerce to Path: ``mkdir`` and the ``/`` operator used in ``save`` need a Path,
        # and callers may legitimately pass a plain string.
        self.model_path = Path(model_path)

        self.model_path.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all artifacts needed by the client and server interfaces.

        Writes ``server.zip`` and ``client.zip`` into ``model_path``.

        Raises:
            AssertionError: If the model has not been compiled (``fhe_circuit`` is None).
        """
        circuit = self.seizure_detector.fhe_circuit

        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if circuit is None:
            raise AssertionError("The model must be compiled before saving it.")

        # Save the circuit for the server, using via_mlir in order to handle cross-platform
        # execution
        path_circuit_server = self.model_path / "server.zip"
        circuit.server.save(path_circuit_server, via_mlir=True)

        # Save the circuit for the client
        path_circuit_client = self.model_path / "client.zip"
        circuit.client.save(path_circuit_client)


class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a SeizureDetector."""

    def __init__(self, key_dir=None, model_path=None):
        """Load the client-side FHE artifacts.

        Args:
            key_dir (Optional[Union[str, Path]]): Directory where the keys are stored. Passed
                unchanged to ``fhe.Client.load``. Defaults to None.
            model_path (Optional[Union[str, Path]]): Directory containing ``client.zip``.
                Defaults to ``SEIZURE_DETECTION_MODEL_PATH`` when None.

        Raises:
            AssertionError: If ``model_path`` does not exist.
        """
        # Resolve the default at call time so the behavior of ``FHEClient()`` is unchanged.
        if model_path is None:
            model_path = SEIZURE_DETECTION_MODEL_PATH
        self.model_path = Path(model_path)
        self.key_dir = key_dir

        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if not self.model_path.exists():
            raise AssertionError(
                f"{self.model_path} does not exist. Please specify a valid path."
            )

        # Load the client
        self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)

        # Instantiate the seizure detector, used to post-process decrypted outputs
        self.seizure_detector = SeizureDetector()

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The serialized evaluation keys, to be sent to the server.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the input image given in the clear.

        No preprocessing is applied here; the image is encrypted as given.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output):
        """Deserialize, decrypt and post-process the server output.

        Args:
            serialized_encrypted_output (bytes): The serialized and encrypted output.

        Returns:
            The result of ``SeizureDetector.post_processing`` applied to the decrypted
            output (expected to indicate whether a seizure was detected).
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)

        # Decrypt the output
        output = self.client.decrypt(encrypted_output)

        # Post-process the output in the clear
        return self.seizure_detector.post_processing(output)