# potential_people / predict.py
# MotoPanda's picture
# Update predict.py
# d109898 verified
from catboost import CatBoostRegressor
from heuristic import HeuristicRegressor
import numpy as np
import sys
import re
import logging
import json
# Module-level logger; DEBUG is enabled so the fallback notice below is not
# filtered out by this logger's own level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class CatBoostPredictor:
    """Predict a wage with a CatBoost model, falling back to a heuristic.

    The boosting model is queried first.  If its prediction is below
    ``fallback_threshold * wage``, the heuristic model is used instead and
    its output is applied as a relative increase over ``wage``.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold=1.1):
        """Load both models.

        Args:
            model_path: Path to the CatBoost model saved in JSON format.
            fallback_model_path: Path to a JSON file that is passed to
                ``HeuristicRegressor.from_json`` after parsing.
            fallback_threshold: Multiplier applied to ``wage``; boosting
                predictions below ``fallback_threshold * wage`` trigger the
                heuristic fallback.
        """
        self.model = CatBoostRegressor().load_model(model_path, format="json")
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default text encoding.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)
        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str):
        """Return the predicted wage for a single person.

        Args:
            wage: Current wage; its natural logarithm is fed to the model.
            age: Age, fed to the model unchanged.
            business_segment: Business segment label, fed to the model as-is.

        Returns:
            ``exp`` of the boosting model output when it is at least
            ``fallback_threshold * wage``; otherwise
            ``wage * (1 + heuristic_prediction)``.
        """
        # NOTE(review): NumPy coerces this mixed float/str list into a string
        # array; presumably CatBoost parses the numeric columns back — confirm.
        X = np.array([np.log(wage), age, business_segment]).reshape((1, -1))
        # The model output is exponentiated, mirroring the log applied to wage.
        boosting_predict = np.maximum(0.0, np.exp(self.model.predict(X)))[0]

        # If the model is wrong, fall back to the heuristic.
        # TODO: think of a better method later.
        # TODO: fallback_threshold needs analytics; defaulting to 1.1 for now.
        if boosting_predict < self.fallback_threshold * wage:
            # Use the module logger (not the root logger) so the DEBUG level
            # configured above actually applies to this message.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))
        return boosting_predict