Spaces:
Sleeping
Sleeping
from catboost import CatBoostRegressor | |
from heuristic import HeuristicRegressor | |
import numpy as np | |
import sys | |
import re | |
import logging | |
import json | |
# Logger named after this module. Its level is pinned to DEBUG so that debug
# records emitted through it are not filtered out at the logger level.
logger = logging.getLogger(name=__name__)
logger.setLevel(level="DEBUG")
class CatBoostPredictor:
    """Predict from (wage, age, business segment) with a CatBoost model.

    When the CatBoost prediction falls below ``fallback_threshold * wage``,
    the result is instead computed from a heuristic fallback model.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold=1.1):
        """Load the CatBoost model and the heuristic fallback model.

        Args:
            model_path: Path to a CatBoost model stored in CatBoost's JSON
                model format.
            fallback_model_path: Path to a JSON file whose parsed contents are
                passed to ``HeuristicRegressor.from_json``.
            fallback_threshold: Multiplier applied to ``wage``; a CatBoost
                prediction strictly below ``fallback_threshold * wage``
                triggers the heuristic fallback.
        """
        self.model = CatBoostRegressor().load_model(model_path, format="json")
        # JSON is UTF-8 by specification; don't rely on the platform default.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)
        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str):
        """Return the prediction for a single record.

        The CatBoost model receives ``[log(wage), age, business_segment]`` and
        its output is exponentiated (the model presumably predicts in log
        space — TODO confirm against the training code).

        Args:
            wage: Wage value. NOTE(review): ``np.log`` yields -inf/nan for
                ``wage <= 0``; callers presumably pass positive wages.
            age: Age feature, passed to the model unchanged.
            business_segment: Categorical segment label.

        Returns:
            The exponentiated CatBoost prediction if it is at least
            ``fallback_threshold * wage``; otherwise
            ``wage * (1 + h)``, where ``h`` is the heuristic model's prediction
            for the same record.
        """
        # dtype=object keeps log(wage) and age numeric. A plain np.array() over
        # mixed float/str values coerces every element to a string.
        X = np.array([[np.log(wage), age, business_segment]], dtype=object)
        # np.exp is never negative, so no clamping at 0.0 is needed.
        boosting_predict = np.exp(self.model.predict(X))[0]
        # If the model looks wrong, fall back to the heuristic.
        # TODO: find a better detection method than a fixed threshold.
        # TODO: fallback_threshold needs analytics; 1.1 is a placeholder default.
        if boosting_predict < self.fallback_threshold * wage:
            # Use the module logger: logging.debug() would go to the root
            # logger (bypassing this module's DEBUG level) and implicitly call
            # logging.basicConfig() when the root has no handlers.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))
        return boosting_predict