File size: 3,637 Bytes
9a997e4
 
c119738
 
 
74c0c8e
c119738
74c0c8e
c119738
74c0c8e
c119738
 
 
 
 
 
 
 
74c0c8e
c119738
74c0c8e
c119738
 
 
 
 
 
 
 
 
 
18ba8c1
 
 
74c0c8e
 
 
18ba8c1
74c0c8e
 
 
 
 
 
 
 
 
 
 
316f8e9
74c0c8e
 
 
 
 
18ba8c1
 
c119738
9a997e4
c119738
 
 
9a997e4
c119738
9a997e4
 
c119738
 
9a997e4
c119738
 
 
 
 
 
 
 
 
 
74c0c8e
c119738
 
74c0c8e
c119738
 
74c0c8e
 
 
c119738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Modified classes for use for Client-Server interface with multi-inputs circuits."""

import numpy
import copy

from typing import Tuple

from concrete.fhe import Value, EvaluationKeys
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import DecisionTreeClassifier 


class MultiInputsFHEModelDev(FHEModelDev):
    """Development-side API variant that re-tags its model as a DecisionTreeClassifier."""

    def __init__(self, *arg, **kwargs):
        """Initialize the parent class, then swap the model for a re-classed shallow copy.

        Args:
            *arg: Positional arguments forwarded unchanged to FHEModelDev.
            **kwargs: Keyword arguments forwarded unchanged to FHEModelDev.
        """
        super().__init__(*arg, **kwargs)

        # Workaround that enables loading a modified version of a DecisionTreeClassifier model:
        # a shallow copy is re-classed so that the object originally stored by the parent
        # initializer keeps its own class
        patched_model = copy.copy(self.model)
        patched_model.__class__ = DecisionTreeClassifier
        self.model = patched_model


class MultiInputsFHEModelClient(FHEModelClient):
    """Client-side API for a circuit that takes several encrypted inputs.

    On top of FHEModelClient, this class stores the number of inputs expected by the circuit
    and provides a method to quantize, encrypt and serialize the input of a single party.
    """

    def __init__(self, *args, nb_inputs=1, **kwargs):
        """Store the number of circuit inputs and initialize the parent class.

        Args:
            *args: Positional arguments forwarded unchanged to FHEModelClient.
            nb_inputs (int): The number of inputs expected by the circuit. Default to 1.
            **kwargs: Keyword arguments forwarded unchanged to FHEModelClient.
        """
        self.nb_inputs = nb_inputs

        super().__init__(*args, **kwargs)

    def quantize_encrypt_serialize_multi_inputs(
        self,
        x: numpy.ndarray,
        input_index: int,
        processed_input_shape: Tuple[int, ...],
        input_slice: slice,
    ) -> bytes:
        """Quantize, encrypt and serialize inputs for a multi-party model.

        In the following, the 'quantize_input' method called is the one defined in Concrete ML's
        built-in models. Since they don't natively handle inputs for multi-party models, we need
        to use padding along indexing and slicing so that inputs from a specific party are
        correctly associated with input quantizers.

        Args:
            x (numpy.ndarray): The input to consider. Here, the input should only represent a
                single party.
            input_index (int): The index representing the type of model (0: "user", 1: "bank",
                2: "cs_agency")
            processed_input_shape (Tuple[int, ...]): The total input shape (all parties combined)
                after pre-processing.
            input_slice (slice): The slices to consider for the given party.

        Returns:
            bytes: The quantized, encrypted and serialized input of the given party.
        """
        # Place the party's values at their columns in an otherwise zero-filled full-size input
        x_padded = numpy.zeros(processed_input_shape)
        x_padded[:, input_slice] = x

        # Quantize the full-size input, then keep only the columns owned by this party
        q_x_padded = self.model.quantize_input(x_padded)
        q_x = q_x_padded[:, input_slice]

        # Only the slot of the current party is filled, all the other inputs are left to None
        q_x_inputs = [None] * self.nb_inputs
        q_x_inputs[input_index] = q_x

        # Encrypt the values
        q_x_enc = self.client.encrypt(*q_x_inputs)

        # When a single argument is given (nb_inputs=1, the default), the encryption returns
        # the encrypted value itself instead of a tuple: normalize it so that indexing works
        # for any number of inputs
        if not isinstance(q_x_enc, tuple):
            q_x_enc = (q_x_enc,)

        # Serialize the encrypted values to be sent to the server
        q_x_enc_ser = q_x_enc[input_index].serialize()
        return q_x_enc_ser
    

class MultiInputsFHEModelServer(FHEModelServer):
    """Server-side API able to run a circuit on several encrypted inputs."""

    def run(
        self,
        *serialized_encrypted_quantized_data: Tuple[bytes],
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Execute the loaded model over the encrypted inputs of a multi-party model.

        Args:
            serialized_encrypted_quantized_data (Tuple[bytes]): The encrypted, quantized
                and serialized data for a multi-party model, one positional argument per input.
            serialized_evaluation_keys (bytes): The serialized evaluation key. It must be
                passed by keyword.

        Returns:
            bytes: The serialized result of the model.
        """
        assert self.server is not None, "Model has not been loaded."

        # Inputs are deserialized first, then the evaluation keys
        encrypted_inputs = tuple(map(Value.deserialize, serialized_encrypted_quantized_data))
        evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)

        encrypted_output = self.server.run(*encrypted_inputs, evaluation_keys=evaluation_keys)
        return encrypted_output.serialize()