# potential_people / predict.py
# MotoPanda's picture
# Update predict.py
# d109898 verified
# raw
# history blame
# 1.32 kB
from catboost import CatBoostRegressor
from heuristic import HeuristicRegressor
import numpy as np
import sys
import re
import logging
import json
# Module-level logger named after this module. Its own threshold is lowered to
# DEBUG so debug records are not filtered out at this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class CatBoostPredictor:
    """Predict with a CatBoost model, falling back to a heuristic model.

    The CatBoost output is exponentiated before use. When the resulting value
    is below ``fallback_threshold * wage`` it is rejected and the heuristic
    model is used instead.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold=1.1):
        """Load the CatBoost model and the heuristic fallback model.

        Args:
            model_path: Location of the CatBoost model stored in JSON format.
            fallback_model_path: Location of a JSON file whose parsed content
                is handed to ``HeuristicRegressor.from_json``.
            fallback_threshold: Multiplier applied to ``wage``; boosting
                predictions below ``fallback_threshold * wage`` trigger the
                fallback.
        """
        self.model = CatBoostRegressor().load_model(model_path, format="json")
        # Explicit encoding: the default depends on the platform locale.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)
        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str):
        """Return a prediction for a single observation.

        Args:
            wage: Wage value; its natural logarithm is the first feature and
                it is also the baseline for the fallback check.
            age: Second feature passed to the model.
            business_segment: Third feature passed to the model.

        Returns:
            The exponentiated CatBoost prediction, or
            ``wage * (1 + heuristic_prediction)`` when the CatBoost result is
            below ``fallback_threshold * wage``.
        """
        # Single-row feature matrix in the order [log(wage), age, segment].
        # NOTE(review): mixing floats and a str in np.array yields a
        # string-typed array, so the numeric features reach the model as
        # strings - confirm this matches how the model expects its input.
        X = np.array([np.log(wage), age, business_segment]).reshape((1, -1))
        boosting_predict = np.maximum(0.0, np.exp(self.model.predict(X)))[0]

        # If the boosting model looks wrong, fall back to the heuristic.
        # TODO: think of a better method later.
        # TODO: fallback_threshold needs analytics; defaulting to 1.1 for now.
        if boosting_predict < self.fallback_threshold * wage:
            # Use the module logger (not the root logger) so the record is
            # attributed to this module and honours its DEBUG level.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))
        return boosting_predict