MotoPanda committed on
Commit
b797b5c
·
verified ·
1 Parent(s): 3cad32c

Create predict.py

Browse files
Files changed (1) hide show
  1. predict.py +32 -0
predict.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from catboost import CatBoostRegressor
2
+ from .heuristic import HeuristicRegressor
3
+ import numpy as np
4
+ import sys
5
+ import re
6
+ import logging
7
+ import json
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logger.setLevel(logging.DEBUG)
12
+
13
class CatBoostPredictor:
    """Predict a value from wage, age and business segment.

    The primary estimate comes from a CatBoost regressor that is fed
    ``log(wage)`` and whose output is exponentiated. When that estimate is
    lower than ``fallback_threshold * wage`` the boosting result is
    discarded and a ``HeuristicRegressor`` is used instead.
    """

    def __init__(self, model_path, fallback_model_path, fallback_threshold=1.1):
        """Load both models.

        Args:
            model_path: Path to the CatBoost model saved in JSON format.
            fallback_model_path: Path to a JSON file that is parsed and
                handed to ``HeuristicRegressor.from_json``.
            fallback_threshold: Multiplier applied to ``wage``; boosting
                predictions below ``fallback_threshold * wage`` trigger the
                heuristic fallback.
        """
        self.model = CatBoostRegressor().load_model(model_path, format="json")

        # Read the JSON with an explicit encoding so the result does not
        # depend on the platform's default locale.
        with open(fallback_model_path, encoding="utf-8") as f:
            fallback_model_json = json.load(f)
        self.fallback_model = HeuristicRegressor.from_json(fallback_model_json)
        self.fallback_threshold = fallback_threshold

    def predict(self, wage: float, age: float, business_segment: str):
        """Return the prediction for a single observation.

        Args:
            wage: Current wage; its natural logarithm is the first feature.
            age: Age, passed to the model unchanged.
            business_segment: Categorical segment label.

        Returns:
            The exponentiated boosting prediction, or
            ``wage * (1 + heuristic_prediction)`` when the boosting
            prediction is below ``fallback_threshold * wage``.
        """
        # dtype=object keeps the numeric features as numbers. Without it
        # NumPy coerces the mixed float/str row into a unicode-string array
        # and the model receives the numeric features as text.
        X = np.array(
            [np.log(wage), age, business_segment], dtype=object
        ).reshape((1, -1))
        boosting_predict = np.maximum(0.0, np.exp(self.model.predict(X)))[0]

        # If the model is wrong, fall back to the heuristic.
        # TODO: think of a better method later.
        # TODO: fallback_threshold needs analytics; defaulting to 1.1 for now.
        if boosting_predict < self.fallback_threshold * wage:
            # Use the module-level logger (not the root logger) so the level
            # configured for this module applies.
            logger.debug("Falling back to heuristic")
            heuristic_features = {
                "avg_wage": wage,
                "age": age,
                "business_category": business_segment,
            }
            return wage * (1 + self.fallback_model.predict(heuristic_features))
        return boosting_predict