|
"""Modified classes for using the Client-Server interface with multi-input circuits."""
|
|
|
import numpy |
|
import copy |
|
|
|
from typing import Tuple |
|
|
|
from concrete.fhe import Value, EvaluationKeys |
|
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer |
|
from concrete.ml.sklearn import DecisionTreeClassifier |
|
|
|
|
|
class MultiInputsFHEModelDev(FHEModelDev):
    """Development-side FHE model wrapper for multi-input models.

    After the parent ``FHEModelDev`` constructor has run, ``self.model`` is replaced by a shallow
    copy whose class is forced to Concrete ML's ``DecisionTreeClassifier``.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent wrapper, then re-class a shallow copy of the wrapped model.

        Args:
            *args: Positional arguments forwarded to ``FHEModelDev.__init__``.
            **kwargs: Keyword arguments forwarded to ``FHEModelDev.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Re-class a shallow copy so the object previously referenced by ``self.model`` keeps its
        # own class. NOTE(review): presumably this makes the saved artifacts refer to the built-in
        # DecisionTreeClassifier instead of a custom multi-input subclass — confirm.
        casted_model = copy.copy(self.model)
        casted_model.__class__ = DecisionTreeClassifier
        self.model = casted_model
|
|
|
|
|
class MultiInputsFHEModelClient(FHEModelClient):
    """Client-side FHE model wrapper that encrypts the inputs of one party at a time.

    The underlying circuit takes ``nb_inputs`` inputs. Each call to
    ``quantize_encrypt_serialize_multi_inputs`` encrypts and serializes only the input at a single
    position.
    """

    def __init__(self, *args, nb_inputs=1, **kwargs):
        """Store the number of circuit inputs, then defer to ``FHEModelClient``.

        Args:
            *args: Positional arguments forwarded to ``FHEModelClient.__init__``.
            nb_inputs (int): The number of inputs the FHE circuit expects. Defaults to 1.
            **kwargs: Keyword arguments forwarded to ``FHEModelClient.__init__``.
        """
        # Assigned before the parent constructor runs (original ordering preserved).
        # NOTE(review): presumably so the attribute is available during parent initialization.
        self.nb_inputs = nb_inputs

        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self,
        x: numpy.ndarray,
        input_index: int,
        processed_input_shape: Tuple[int, ...],
        input_slice: slice,
    ) -> bytes:
        """Quantize, encrypt and serialize inputs for a multi-party model.

        In the following, the 'quantize_input' method called is the one defined in Concrete ML's
        built-in models. Since they don't natively handle inputs for multi-party models, we need
        to use padding along indexing and slicing so that inputs from a specific party are
        correctly associated with input quantizers.

        Args:
            x (numpy.ndarray): The input to consider. Here, the input should only represent a
                single party.
            input_index (int): The index representing the type of model (0: "user", 1: "bank",
                2: "third_party"). It selects which circuit input is encrypted.
            processed_input_shape (Tuple[int, ...]): The total input shape (all parties combined)
                after pre-processing. It needs at least two dimensions, since columns are
                selected with ``[:, input_slice]``.
            input_slice (slice): The slices to consider for the given party.

        Returns:
            bytes: The serialized encrypted input of the given party.
        """
        # Put this party's columns into an all-zero matrix that spans every party's features, so
        # that each column is quantized by its own input quantizer.
        x_padded = numpy.zeros(processed_input_shape)
        x_padded[:, input_slice] = x

        q_x_padded = self.model.quantize_input(x_padded)

        # Keep only the quantized columns that belong to this party.
        q_x = q_x_padded[:, input_slice]

        # Only this party's position carries data; the other circuit inputs are passed as None.
        q_x_inputs = [None] * self.nb_inputs
        q_x_inputs[input_index] = q_x

        q_x_enc = self.client.encrypt(*q_x_inputs)

        # Concrete's ``Client.encrypt`` returns a bare ``Value`` (not a tuple) when the circuit has
        # a single input, so normalize before indexing.
        if not isinstance(q_x_enc, tuple):
            q_x_enc = (q_x_enc,)

        return q_x_enc[input_index].serialize()
|
|
|
|
|
class MultiInputsFHEModelServer(FHEModelServer):
    """Server-side FHE model wrapper that evaluates a circuit taking several encrypted inputs."""

    def run(
        self,
        *serialized_encrypted_quantized_data: bytes,
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Run the model on the server over encrypted data for a multi-party model.

        Args:
            *serialized_encrypted_quantized_data (bytes): The encrypted, quantized and serialized
                inputs of a multi-party model. Pass one positional argument per circuit input;
                they are forwarded to the circuit in the order given.
            serialized_evaluation_keys (bytes): The serialized evaluation keys (keyword-only).

        Returns:
            bytes: The serialized encrypted result of the model.

        Raises:
            AssertionError: If the model has not been loaded (``self.server`` is None).
        """
        assert self.server is not None, "Model has not been loaded."

        deserialized_encrypted_quantized_data = tuple(
            Value.deserialize(data) for data in serialized_encrypted_quantized_data
        )

        deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)

        result = self.server.run(
            *deserialized_encrypted_quantized_data,
            evaluation_keys=deserialized_evaluation_keys,
        )

        return result.serialize()