|
"""Modified model class to handles multi-inputs circuit.""" |
|
|
|
import numpy |
|
import time |
|
from typing import Optional, Sequence, Union |
|
|
|
from concrete.fhe.compilation.compiler import Compiler, Configuration, DebugArtifacts, Circuit |
|
|
|
from concrete.ml.common.check_inputs import check_array_and_assert |
|
from concrete.ml.common.utils import ( |
|
generate_proxy_function, |
|
manage_parameters_for_pbs_errors, |
|
check_there_is_no_p_error_options_in_configuration |
|
) |
|
from concrete.ml.quantization.quantized_module import QuantizedModule, _get_inputset_generator |
|
from concrete.ml.sklearn import DecisionTreeClassifier, DecisionTreeRegressor |
|
|
|
class MultiInputModel: |
|
|
|
def quantize_input(self, *X: numpy.ndarray) -> numpy.ndarray: |
|
self.check_model_is_fitted() |
|
assert sum(input.shape[1] for input in X) == len(self.input_quantizers) |
|
|
|
base_j = 0 |
|
q_inputs = [] |
|
for i, input in enumerate(X): |
|
q_input = numpy.zeros_like(input, dtype=numpy.int64) |
|
|
|
for j in range(input.shape[1]): |
|
quantizer_index = base_j + j |
|
q_input[:, j] = self.input_quantizers[quantizer_index].quant(input[:, j]) |
|
|
|
assert q_input.dtype == numpy.int64, f"Inputs {i} were not quantized to int64 values" |
|
|
|
q_inputs.append(q_input) |
|
base_j += input.shape[1] |
|
|
|
return tuple(q_inputs) if len(q_inputs) > 1 else q_inputs[0] |
|
|
|
def compile( |
|
self, |
|
*inputs, |
|
configuration: Optional[Configuration] = None, |
|
artifacts: Optional[DebugArtifacts] = None, |
|
show_mlir: bool = False, |
|
p_error: Optional[float] = None, |
|
global_p_error: Optional[float] = None, |
|
verbose: bool = False, |
|
inputs_encryption_status: Optional[Sequence[str]] = None, |
|
) -> Circuit: |
|
|
|
|
|
self.check_model_is_fitted() |
|
|
|
|
|
inputs_as_array = [] |
|
for input in inputs: |
|
input_as_array = check_array_and_assert(input) |
|
inputs_as_array.append(input_as_array) |
|
|
|
inputs_as_array = tuple(inputs_as_array) |
|
|
|
|
|
check_there_is_no_p_error_options_in_configuration(configuration) |
|
|
|
|
|
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error) |
|
|
|
|
|
quantized_inputs = self.quantize_input(*inputs_as_array) |
|
|
|
|
|
inputset = _get_inputset_generator(quantized_inputs) |
|
|
|
|
|
self._is_compiled = False |
|
|
|
|
|
module_to_compile = self._get_module_to_compile(inputs_encryption_status) |
|
|
|
|
|
assert isinstance(module_to_compile, Compiler), ( |
|
"Wrong module to compile. Expected to be of type `Compiler` but got " |
|
f"{type(module_to_compile)}." |
|
) |
|
|
|
|
|
|
|
self.fhe_circuit_ = module_to_compile.compile( |
|
inputset, |
|
configuration=configuration, |
|
artifacts=artifacts, |
|
show_mlir=show_mlir, |
|
p_error=p_error, |
|
global_p_error=global_p_error, |
|
verbose=verbose, |
|
single_precision=False, |
|
fhe_simulation=False, |
|
fhe_execution=True, |
|
jit=False, |
|
) |
|
|
|
self._is_compiled = True |
|
|
|
|
|
assert isinstance(self.fhe_circuit, Circuit) |
|
|
|
return self.fhe_circuit |
|
|
|
def _get_module_to_compile(self, inputs_encryption_status) -> Union[Compiler, QuantizedModule]: |
|
assert self._tree_inference is not None, self._is_not_fitted_error_message() |
|
|
|
if not self._is_compiled: |
|
xgb_inference = self._tree_inference |
|
self._tree_inference = lambda *args: xgb_inference(numpy.concatenate(args, axis=1)) |
|
|
|
input_names = [f"input_{i}_encrypted" for i in range(len(inputs_encryption_status))] |
|
|
|
|
|
_tree_inference_proxy, function_arg_names = generate_proxy_function( |
|
self._tree_inference, input_names |
|
) |
|
|
|
inputs_encryption_statuses = {input_name: status for input_name, status in zip(function_arg_names.values(), inputs_encryption_status)} |
|
|
|
|
|
compiler = Compiler( |
|
_tree_inference_proxy, |
|
inputs_encryption_statuses, |
|
) |
|
|
|
return compiler |
|
|
|
class MultiInputDecisionTreeClassifier(MultiInputModel, DecisionTreeClassifier): |
|
pass |
|
|
|
class MultiInputDecisionTreeRegressor(MultiInputModel, DecisionTreeRegressor): |
|
pass |
|
|