# NOTE: the lines that preceded this code were web-page scrape residue
# (file size: 1,320 bytes; revision ids b797b5c, d109898, b797b5c; a 1-32
# line-number gutter) and are not part of the program.
from catboost import CatBoostRegressor
from heuristic import HeuristicRegressor
import numpy as np
import sys
import re
import logging
import json


# Module-level logger for this file. Its level is forced to DEBUG so debug
# records emitted through it are not filtered at the logger itself; whether
# they are actually shown still depends on the configured handlers.
# NOTE(review): setting a level at import time overrides the application's
# logging configuration for this module; consider leaving it to the app.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class CatBoostPredictor:
    """Predict a value from wage, age and business segment.

    A CatBoost regressor is queried first. It is fed ``log(wage)`` and its
    output is exponentiated, i.e. the model works in log space. When the
    boosted prediction is below ``fallback_threshold * wage`` it is rejected
    and a heuristic model is used instead.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold: float = 1.1):
        """Load the CatBoost model and the heuristic fallback model.

        Args:
            model_path: Path to a CatBoost model saved in JSON format.
            fallback_model_path: Path to a JSON file whose parsed content is
                passed to ``HeuristicRegressor.from_json``.
            fallback_threshold: Multiple of ``wage`` below which the CatBoost
                prediction is replaced by the heuristic one.
        """
        self.model = CatBoostRegressor().load_model(model_path, format="json")

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)
        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str):
        """Return the prediction for a single record.

        Args:
            wage: Current wage. Should be positive, because the model is fed
                ``np.log(wage)`` (non-positive values give -inf / nan).
            age: Age feature, passed unchanged to both models.
            business_segment: Categorical segment label.

        Returns:
            ``exp`` of the CatBoost prediction or, when that is below
            ``fallback_threshold * wage``, ``wage * (1 + h)`` where ``h`` is
            the heuristic model's prediction (presumably a relative uplift).
        """
        # dtype=object keeps the numeric features as floats; a plain np.array
        # over mixed float/str values coerces every element to a string.
        X = np.array([[np.log(wage), age, business_segment]], dtype=object)
        # The model predicts in log space; exp() is never negative, so no
        # clamping at zero is needed.
        boosting_predict = np.exp(self.model.predict(X))[0]

        # A prediction below the threshold is treated as a model failure.
        # TODO: find a better fallback criterion, and calibrate
        # fallback_threshold from analytics (1.1 is a placeholder).
        if boosting_predict < self.fallback_threshold * wage:
            # Log via the module logger: the root logger's default WARNING
            # level would silently drop this debug message.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))
        return boosting_predict