File size: 3,637 Bytes
9a997e4 c119738 74c0c8e c119738 74c0c8e c119738 74c0c8e c119738 74c0c8e c119738 74c0c8e c119738 18ba8c1 74c0c8e 18ba8c1 74c0c8e 316f8e9 74c0c8e 18ba8c1 c119738 9a997e4 c119738 9a997e4 c119738 9a997e4 c119738 9a997e4 c119738 74c0c8e c119738 74c0c8e c119738 74c0c8e c119738 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
"""Modified classes for use for Client-Server interface with multi-inputs circuits."""
import numpy
import copy
from typing import Tuple
from concrete.fhe import Value, EvaluationKeys
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import DecisionTreeClassifier
class MultiInputsFHEModelDev(FHEModelDev):
    """Development-side helper for multi-input circuits.

    Behaves like ``FHEModelDev`` except that, once the parent initializer has run, the stored
    model is replaced by a shallow copy whose class is forced to ``DecisionTreeClassifier``.
    """

    def __init__(self, *arg, **kwargs):
        """Forward all arguments to ``FHEModelDev``, then re-class the stored model.

        Args:
            *arg: Positional arguments passed through to ``FHEModelDev``.
            **kwargs: Keyword arguments passed through to ``FHEModelDev``.
        """
        super().__init__(*arg, **kwargs)

        # Workaround that enables loading a modified version of a DecisionTreeClassifier model.
        # The class is changed on a shallow copy, so the object originally stored in
        # ``self.model`` keeps its own class.
        retyped_model = copy.copy(self.model)
        retyped_model.__class__ = DecisionTreeClassifier
        self.model = retyped_model
class MultiInputsFHEModelClient(FHEModelClient):
    """Client-side helper for circuits that take one encrypted input per party."""

    def __init__(self, *args, nb_inputs=1, **kwargs):
        """Record the number of circuit inputs, then run the ``FHEModelClient`` initializer.

        Args:
            *args: Positional arguments passed through to ``FHEModelClient``.
            nb_inputs (int): The number of inputs the circuit takes. Default to 1.
            **kwargs: Keyword arguments passed through to ``FHEModelClient``.
        """
        # Assigned first so the attribute is already available while the parent initializer runs
        self.nb_inputs = nb_inputs
        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self,
        x: numpy.ndarray,
        input_index: int,
        processed_input_shape: Tuple[int, ...],
        input_slice: slice,
    ) -> bytes:
        """Quantize, encrypt and serialize inputs for a multi-party model.

        In the following, the 'quantize_input' method called is the one defined in Concrete ML's
        built-in models. Since they don't natively handle inputs for multi-party models, we need
        to use padding along indexing and slicing so that inputs from a specific party are
        correctly associated with input quantizers.

        Args:
            x (numpy.ndarray): The input to consider. Here, the input should only represent a
                single party.
            input_index (int): The index representing the type of model (0: "user", 1: "bank",
                2: "cs_agency"). It selects the position of this party among the 'nb_inputs'
                circuit inputs.
            processed_input_shape (Tuple[int, ...]): The total input shape (all parties
                combined) after pre-processing. The second axis is the one sliced by
                'input_slice'.
            input_slice (slice): The slices to consider for the given party.

        Returns:
            bytes: The serialized encrypted quantized values of the given party.

        Raises:
            IndexError: If 'input_index' does not designate one of the 'nb_inputs' inputs.
        """
        # Fail before doing any quantization work. Negative indices stay accepted, exactly as
        # the list assignment below accepts them.
        if not -self.nb_inputs <= input_index < self.nb_inputs:
            raise IndexError(
                f"input_index {input_index} is out of range for a circuit with "
                f"{self.nb_inputs} input(s)."
            )

        # Put the party's values at their columns inside a zero-filled array that has the full
        # (all parties) shape, since 'quantize_input' works on the complete input
        x_padded = numpy.zeros(processed_input_shape)
        x_padded[:, input_slice] = x

        # Quantize the full array, then keep only the columns that belong to this party
        q_x_padded = self.model.quantize_input(x_padded)
        q_x = q_x_padded[:, input_slice]

        # Every other input position is left to None: only this party's values are provided
        q_x_inputs = [None] * self.nb_inputs
        q_x_inputs[input_index] = q_x

        # Encrypt the values
        q_x_enc = self.client.encrypt(*q_x_inputs)

        # Serialize the encrypted values to be sent to the server
        return q_x_enc[input_index].serialize()
class MultiInputsFHEModelServer(FHEModelServer):
    """Server-side helper that runs a circuit taking several encrypted inputs."""

    def run(
        self,
        *serialized_encrypted_quantized_data: bytes,
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Run the model on the server over encrypted data for a multi-party model.

        Args:
            *serialized_encrypted_quantized_data (bytes): The encrypted, quantized
                and serialized data for a multi-party model, one positional argument per
                circuit input.
            serialized_evaluation_keys (bytes): The serialized evaluation key. It must be
                passed by keyword.

        Returns:
            bytes: the result of the model

        Raises:
            AssertionError: If no server is available ('self.server' is None).
        """
        assert self.server is not None, "Model has not been loaded."

        # Rebuild the encrypted values, keeping the order in which they were given
        deserialized_encrypted_quantized_data = tuple(
            Value.deserialize(data) for data in serialized_encrypted_quantized_data
        )
        deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)

        result = self.server.run(
            *deserialized_encrypted_quantized_data, evaluation_keys=deserialized_evaluation_keys
        )

        # Serialize the encrypted result so that it can be sent back
        return result.serialize()