# icd10_docker / app.py
# Author: lyangas; commit da7535b ("add voting"); 3.86 kB.
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os
from required_classes import BertEmbedder, PredictModel
# Voting weight per classifier family. Keys are classifier names with the
# '_codes' / '_groups' parts removed (get_top_result performs that mapping).
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}

print('INFO: loading models')
# Best-effort load: on failure the error is reported and `embedder` stays
# undefined, so requests that need it will fail later.
try:
    with open('embedder/embedder.pkl', 'rb') as pickled_embedder:
        embedder = pickle.load(pickled_embedder)
except Exception as load_error:
    print(f"ERROR: loading embedder failed with: {str(load_error)}")
else:
    print('INFO: embedder loaded')
def load_classifiers(directory):
    """Unpickle every non-hidden file in *directory* into a name -> model dict.

    Each key is the file name up to its first dot, for example
    ``svc_codes.pkl`` -> ``svc_codes``. Loading is best-effort:
    - A file that fails to load is reported and skipped, so it no longer
      prevents the remaining classifiers from loading.
    - A missing or unreadable directory yields an empty dict.

    Files are visited in sorted order so the resulting dict order is
    deterministic.

    NOTE(security): ``pickle.load`` can execute arbitrary code. Only point this
    at trusted, locally shipped model files, never at user-supplied data.

    Args:
        directory: Path of the directory containing the pickled classifiers.

    Returns:
        Mapping of classifier name to the unpickled model object.
    """
    classifiers = {}
    try:
        file_names = sorted(os.listdir(directory))
    except OSError as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        return classifiers
    for clf_name in file_names:
        if clf_name.startswith('.'):  # skip hidden entries
            continue
        try:
            with open(os.path.join(directory, clf_name), 'rb') as f:
                model = pickle.load(f)
        except Exception as e:  # report and keep loading the other files
            print(f"ERROR: loading classifier {clf_name} failed with: {str(e)}")
            continue
        classifiers[clf_name.split('.')[0]] = model
        print(f'INFO: classifier {clf_name} loaded')
    return classifiers


classifiers_codes = load_classifiers('classifiers/codes')
classifiers_groups = load_classifiers('classifiers/groups')
def classify_code(text, top_n):
    """Score *text* with every ICD-10 code classifier in ``classifiers_codes``.

    Args:
        text: Raw input text; it is embedded once with ``embedder``.
        top_n: How many highest-probability classes to keep per classifier.
            Values below 1 produce empty per-classifier results. Previously
            0 became the slice ``[-0:]`` and returned every class.

    Returns:
        ``{classifier_name: {class_label: probability}}``. Each inner dict is
        ordered from the most to the least probable class.
    """
    embed = [embedder(text)]  # one-row batch; row 0 of the output is read below
    top_n = max(top_n, 0)
    preds = {}
    for clf_name, model in classifiers_codes.items():
        probs = model.predict_proba(embed)[0]
        # argsort is ascending: reverse it to rank classes best-first.
        best_n = np.argsort(probs)[::-1][:top_n]
        preds[clf_name] = {str(model.classes_[i]): float(probs[i]) for i in best_n}
    return preds
def classify_group(text, top_n):
    """Score *text* with every diagnosis-group classifier in ``classifiers_groups``.

    Args:
        text: Raw input text; it is embedded once with ``embedder``.
        top_n: How many highest-probability classes to keep per classifier.
            Values below 1 produce empty per-classifier results. Previously
            0 became the slice ``[-0:]`` and returned every class.

    Returns:
        ``{classifier_name: {class_label: probability}}``. Each inner dict is
        ordered from the most to the least probable class.
    """
    embed = [embedder(text)]  # one-row batch; row 0 of the output is read below
    top_n = max(top_n, 0)
    preds = {}
    for clf_name, model in classifiers_groups.items():
        probs = model.predict_proba(embed)[0]
        # argsort is ascending: reverse it to rank classes best-first.
        best_n = np.argsort(probs)[::-1][:top_n]
        preds[clf_name] = {str(model.classes_[i]): float(probs[i]) for i in best_n}
    return preds
def get_top_result(preds, threshold=0.5, weights=None):
    """Combine per-classifier scores into one label by weighted voting.

    Each classifier name is mapped to a weight key by removing ``_codes`` /
    ``_groups`` (e.g. ``svc_codes`` -> ``svc``). Every class score is
    multiplied by that weight and the products are summed across classifiers.

    Args:
        preds: ``{classifier_name: {class_label: score}}``, as returned by
            ``classify_code`` / ``classify_group``.
        threshold: The winning total must be strictly greater than this.
        weights: Mapping of weight key to weight; defaults to ``CLS_WEIGHTS``.

    Returns:
        The label with the highest weighted total. Returns ``None`` when
        *preds* holds no scores or the best total does not exceed *threshold*.
        Ties go to the label seen first.

    Raises:
        KeyError: If a classifier name has no entry in *weights*.
    """
    if weights is None:
        weights = CLS_WEIGHTS
    total_scores = {}
    for clf_name, scores in preds.items():
        weight = weights[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weight * score
    if not total_scores:
        return None
    # The previous np.array(total_scores.values()) built a 0-d object array.
    # Its argmax() is always 0, so the first-seen label won instead of the best.
    best_class = max(total_scores, key=total_scores.get)
    return best_class if total_scores[best_class] > threshold else None
app = Flask(__name__)


@app.get("/")
def test_get():
    """Trivial GET endpoint (presumably a smoke test) returning a constant JSON greeting."""
    return dict(hello='world')
@app.route("/test", methods=['POST'])
def test():
    """Echo endpoint: return the parsed JSON request body under the 'response' key."""
    return {'response': request.json}
@app.route("/predict", methods=['POST'])
def read_root():
    """Predict ICD-10 codes and diagnosis groups for a base64-encoded text.

    Expects a JSON object body ``{"textB64": str, "top_n": int}``. Every code
    and group classifier scores the decoded text. For each family the
    response contains:
    - ``details``: the per-classifier top-``top_n`` scores.
    - ``result``: the weighted vote from ``get_top_result``, possibly ``None``.

    Returns:
        On success, ``{"icd10": {...}, "dx_group": {...}}``.
        ``{"error": ...}`` with status 200 when ``top_n < 1`` or the text is
        blank (legacy behavior, kept for existing clients).
        ``{"error": ...}`` with status 400 when:
        - the body is not a JSON object;
        - a field is missing;
        - ``top_n`` is not an integer;
        - ``textB64`` is not valid base64-encoded UTF-8.
        These cases previously crashed with status 500.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return {'error': 'request body should be a JSON object'}, 400
    missing = [key for key in ('textB64', 'top_n') if key not in data]
    if missing:
        return {'error': f"missing field(s): {', '.join(missing)}"}, 400
    try:
        base64_bytes = str(data['textB64']).encode("ascii")
        # Decode as UTF-8 (a superset of ASCII). Previously accepted inputs are
        # unchanged, and non-English text no longer raises UnicodeDecodeError.
        text = base64.b64decode(base64_bytes).decode("utf-8")
    except ValueError:  # covers binascii.Error and UnicodeEncode/DecodeError
        return {'error': 'textB64 should be base64-encoded UTF-8 text'}, 400
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}, 400
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}
    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    return {
        "icd10": {'result': get_top_result(pred_codes), 'details': pred_codes},
        "dx_group": {'result': get_top_result(pred_groups), 'details': pred_groups},
    }
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860, via Flask's built-in server.
    app.run(port=7860, host='0.0.0.0')