# encrypted_credit_scoring / utils / pre_processing.py
# Author: romanbredehoft-zama
# Commit 18ba8c1: "Update to synthetic data-set"
# (raw | history | blame; 2.4 kB)
"""Data pre-processing functions."""
import numpy
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, FunctionTransformer, StandardScaler
def _get_pipeline_replace_one_hot(func, value):
    """Build a two-step pipeline: remap values with ``func``, then one-hot encode them.

    Args:
        func: Callable wrapped in a ``FunctionTransformer``. It receives the input
            data plus ``value`` as its ``value`` keyword argument.
        value: Object forwarded to ``func`` through ``kw_args={"value": value}``.

    Returns:
        Pipeline: Unfitted pipeline whose steps are ``"replace"`` followed by
        ``"one_hot"`` (a default ``OneHotEncoder``).
    """
    replacer = FunctionTransformer(
        func,
        kw_args={"value": value},
        # Declare that the replacement keeps one output column per input column,
        # so feature names can flow through to the encoder.
        feature_names_out='one-to-one',
    )
    encoder = OneHotEncoder()
    return Pipeline(steps=[("replace", replacer), ("one_hot", encoder)])
def _replace_values_eq(column, value):
    """Collapse groups of raw values in ``column`` into single target values.

    Args:
        column: Array-like of values, for example the data handed over by a
            ``FunctionTransformer``.
        value (dict): Mapping ``{desired_value: values_to_replace}``. Every element
            of ``column`` found in ``values_to_replace`` is replaced by
            ``desired_value``.

    Returns:
        numpy.ndarray: Array with the same shape as ``column`` holding the replaced
        values. Elements that match no group keep their original value. If
        ``value`` is empty, ``column`` itself is returned unchanged.

    Notes:
        Matching is always done against the *original* values, so a replacement is
        never rewritten again by a later group: ``{"b": ["a"], "z": ["b"]}`` maps
        ``"a"`` to ``"b"`` (not ``"z"``) and ``"b"`` to ``"z"``. If a raw value is
        listed in several groups, the first group in the mapping wins.
    """
    result = column
    # Elements already claimed by an earlier group; each element is replaced at most once.
    already_replaced = numpy.zeros(numpy.shape(column), dtype=bool)
    for desired_value, values_to_replace in value.items():
        # Test membership on the untouched input, not on `result`, to avoid chaining.
        mask = numpy.isin(column, values_to_replace) & ~already_replaced
        # numpy.where promotes the dtype, so longer replacement strings are not truncated.
        result = numpy.where(mask, desired_value, result)
        already_replaced |= mask
    return result
def get_pre_processors():
    """Create the unfitted column transformers for each data source.

    Every transformer keeps the columns it does not transform
    (``remainder='passthrough'``) and emits un-prefixed feature names
    (``verbose_feature_names_out=False``).

    - user: groups ``Occupation_type`` values into coarser labels and one-hot encodes
      them; one-hot encodes ``Housing_type``, ``Family_status``, ``Education_type``
      and ``Income_type``; standard-scales ``Num_children``, ``Household_size``,
      ``Total_income`` and ``Age``.
    - bank: standard-scales ``Account_age``.
    - third party: standard-scales ``Years_employed``.

    Returns:
        tuple: ``(pre_processor_user, pre_processor_bank, pre_processor_third_party)``,
        three ``ColumnTransformer`` instances.
    """
    def _scaler_only(columns):
        # A single StandardScaler over `columns`; every other column is passed through.
        return ColumnTransformer(
            transformers=[('standard_scaler', StandardScaler(), columns)],
            remainder='passthrough',
            verbose_feature_names_out=False,
        )

    # Target label -> raw occupation labels collapsed into it. Occupations not listed
    # in any group keep their original label.
    occupation_groups = {
        "Labor_work": [
            "Cooking Staff", "Carpenter", "Plumber", "Factory Worker", "Bus Driver"
        ],
        "Office_work": [
            "Business Owners", "Office Worker", "Accountant", "Entrepreneur", "Salesperson"
        ],
        "High_tech_work": ["Engineer", "Manager", "Consultant", "Software Developer"],
    }

    occupation_step = (
        "replace_occupation_type_labor",
        _get_pipeline_replace_one_hot(_replace_values_eq, occupation_groups),
        ['Occupation_type'],
    )
    categorical_step = (
        'one_hot_others',
        OneHotEncoder(),
        ['Housing_type', 'Family_status', 'Education_type', 'Income_type'],
    )
    numerical_step = (
        'standard_scaler',
        StandardScaler(),
        ['Num_children', 'Household_size', 'Total_income', 'Age'],
    )

    user = ColumnTransformer(
        transformers=[occupation_step, categorical_step, numerical_step],
        remainder='passthrough',
        verbose_feature_names_out=False,
    )
    bank = _scaler_only(['Account_age'])
    third_party = _scaler_only(['Years_employed'])

    return user, bank, third_party