import json
import logging
import re
import sys

import numpy as np
from catboost import CatBoostRegressor

from heuristic import HeuristicRegressor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class CatBoostPredictor:
    """Predict with a CatBoost model, falling back to a heuristic regressor.

    The CatBoost model is fed ``[log(wage), age, business_segment]`` and its
    output is exponentiated. If that prediction is below
    ``fallback_threshold * wage``, the heuristic model is used instead.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold: float = 1.1):
        """Load both models.

        Args:
            model_path: Path to a CatBoost model saved in JSON format.
            fallback_model_path: Path to a JSON file that is parsed and
                handed to ``HeuristicRegressor.from_json``.
            fallback_threshold: Multiplier applied to ``wage``; CatBoost
                predictions below ``fallback_threshold * wage`` are replaced
                by the heuristic prediction.
        """
        # Load in two steps so we do not depend on ``load_model`` returning
        # the model instance.
        model = CatBoostRegressor()
        model.load_model(model_path, format="json")
        self.model = model

        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)

        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str) -> float:
        """Return the prediction for a single observation.

        Args:
            wage: Wage value; its natural log is the first model feature and
                the raw value is the baseline for the fallback check.
                NOTE(review): ``wage <= 0`` makes ``np.log`` produce
                ``-inf``/``nan``; callers are presumably expected to pass a
                positive wage -- confirm.
            age: Second model feature.
            business_segment: Third model feature, passed as a string
                (sent to the heuristic model under ``"business_category"``).

        Returns:
            ``exp`` of the CatBoost output, unless it is below
            ``fallback_threshold * wage``; in that case
            ``wage * (1 + heuristic_prediction)``.
        """
        # dtype=object keeps the numeric features as numbers and the segment
        # as a string; without it NumPy coerces the whole mixed row to a
        # unicode-string array.
        X = np.array(
            [np.log(wage), age, business_segment], dtype=object
        ).reshape((1, -1))

        # The model output is exponentiated (the log of wage is also what is
        # fed in); the clamp at 0.0 is kept from the original as a guard.
        boosting_predict = np.maximum(0.0, np.exp(self.model.predict(X)))[0]

        # If the model looks wrong (prediction below threshold * wage),
        # fall back to the heuristic.
        # TODO: think of a better method later.
        # TODO: fallback_threshold needs analytics; defaulting to 1.1 for now.
        if boosting_predict < self.fallback_threshold * wage:
            # Use the module logger (configured above), not the root logger.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))

        return boosting_predict