opdx / app.py
lyangas's picture
Update app.py
88a4631 verified
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os
from required_classes import *
# Per-classifier weights used by get_top_result when it sums the scores of the
# individual models; the keys are the classifier names with the '_codes' /
# '_groups' suffix removed.
CLS_WEIGHTS = {
    'mlp': 0.3,
    'svc': 0.4,
    'xgboost': 0.3,
}
# Start-up step: unpickle the text embedder that every classifier shares.
# NOTE(review): pickle.load executes code stored in the file, so
# 'embedder/embedder.pkl' must come from a trusted source.
print('INFO: loading models')
try:
    with open('embedder/embedder.pkl', 'rb') as embedder_file:
        embedder = pickle.load(embedder_file)
    # Replace the wrapped model by its `.bert` attribute.
    bert_part = embedder.embedder.bert
    embedder.embedder = bert_part
    print('INFO: embedder loaded')
except Exception as load_error:
    # Report the failure on stdout, then abort start-up with the same error.
    print(f"ERROR: loading embedder failed with: {str(load_error)}")
    raise load_error
# Start-up step: unpickle every model found in 'classifiers/codes'.
# Each model is stored under its file name up to the first '.'.
# NOTE(review): pickle.load executes code stored in the file, so these
# artifacts must come from a trusted source.
print('Loading classifiers of codes')
classifiers_codes = {}
try:
    for file_name in os.listdir('classifiers/codes'):
        if file_name.startswith('.'):
            # Hidden entries (dot-files) are not models.
            continue
        with open(f'classifiers/codes/{file_name}', 'rb') as model_file:
            loaded_model = pickle.load(model_file)
        model_key = file_name.split('.')[0]
        classifiers_codes[model_key] = loaded_model
        print(f'INFO: codes classifier {file_name} loaded')
except Exception as load_error:
    # Report the failure on stdout, then abort start-up with the same error.
    print(f"ERROR: loading classifiers failed with: {str(load_error)}")
    raise load_error
# Start-up step: unpickle every model found in 'classifiers/groups'.
# Each model is stored under its file name up to the first '.'.
# NOTE(review): pickle.load executes code stored in the file, so these
# artifacts must come from a trusted source.
print('Loading classifiers of groups')
classifiers_groups = {}
try:
    for file_name in os.listdir('classifiers/groups'):
        if file_name.startswith('.'):
            # Hidden entries (dot-files) are not models.
            continue
        with open(f'classifiers/groups/{file_name}', 'rb') as model_file:
            loaded_model = pickle.load(model_file)
        model_key = file_name.split('.')[0]
        classifiers_groups[model_key] = loaded_model
        print(f'INFO: groups classifier {file_name} loaded')
except Exception as load_error:
    # Report the failure on stdout, then abort start-up with the same error.
    print(f"ERROR: loading classifiers failed with: {str(load_error)}")
    raise load_error
# Start-up step: unpickle the per-group code classifiers found in
# 'classifiers/codes_in_groups'. The dictionary key is the file name with the
# '_code_clf.pkl' part removed, i.e. the group the model belongs to.
# NOTE(review): pickle.load executes code stored in the file, so these
# artifacts must come from a trusted source.
print('Loading classifiers in groups')
groups_models = {}
try:
    for file_name in os.listdir('classifiers/codes_in_groups'):
        if file_name.startswith('.'):
            # Hidden entries (dot-files) are not models.
            continue
        with open(f'classifiers/codes_in_groups/{file_name}', 'rb') as model_file:
            loaded_model = pickle.load(model_file)
        group_name = file_name.replace('_code_clf.pkl', '')
        groups_models[group_name] = loaded_model
        print(f'INFO: codes classifier for group {group_name} loaded')
except Exception as load_error:
    # Report the failure on stdout, then abort start-up with the same error.
    print(f"ERROR: loading classifiers failed with: {str(load_error)}")
    raise load_error
def classify_code(text, top_n):
    """Run every loaded code classifier on *text*.

    Args:
        text: the input text; it is embedded once with the module-level
            ``embedder`` and the same embedding is given to each model.
        top_n: number of highest-probability classes to keep per model.

    Returns:
        dict mapping classifier name -> {class label (str): probability
        (float)}, with the classes inserted from most to least probable.
    """
    features = [embedder(text)]
    results = {}
    for name, clf in classifiers_codes.items():
        print(name)
        # assumes predict_proba returns a (1, n_classes) array-like — TODO confirm
        proba = clf.predict_proba(features)
        # argsort is ascending: take the last top_n indices, best first.
        ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[idx]): float(proba[0][idx]) for idx in ranked
        }
    return results
def classify_group(text, top_n):
    """Run every loaded group classifier on *text*.

    Args:
        text: the input text; it is embedded once with the module-level
            ``embedder`` and the same embedding is given to each model.
        top_n: number of highest-probability classes to keep per model.

    Returns:
        dict mapping classifier name -> {class label (str): probability
        (float)}, with the classes inserted from most to least probable.
    """
    features = [embedder(text)]
    results = {}
    for name, clf in classifiers_groups.items():
        # assumes predict_proba returns a (1, n_classes) array-like — TODO confirm
        proba = clf.predict_proba(features)
        # argsort is ascending: take the last top_n indices, best first.
        ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[idx]): float(proba[0][idx]) for idx in ranked
        }
    return results
def classify_code_by_group(text, group_name, top_n):
    """Classify *text* with the model registered for one group.

    Args:
        text: the input text, embedded with the module-level ``embedder``.
        group_name: key into ``groups_models`` selecting the model.
        top_n: number of highest-probability classes to report.

    Returns:
        A 3-tuple ``(top class, {class label (str): probability (float)} for
        the top_n classes ordered best first, the model's ``classes_``)``.

    Raises:
        KeyError: if ``group_name`` has no entry in ``groups_models``.
    """
    features = [embedder(text)]
    clf = groups_models[group_name]
    # assumes predict_proba returns a (1, n_classes) array-like — TODO confirm
    proba = clf.predict_proba(features)
    # argsort is ascending: take the last top_n indices, best first.
    ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
    scores_by_label = {
        str(clf.classes_[idx]): float(proba[0][idx]) for idx in ranked
    }
    best_label = clf.classes_[ranked[0]]
    return best_label, scores_by_label, clf.classes_
def get_top_result(preds):
    """Combine per-classifier predictions into one weighted consensus label.

    Every probability is multiplied by the weight of the classifier that
    produced it (looked up in ``CLS_WEIGHTS``) and the products are summed per
    class label.

    Args:
        preds: dict mapping classifier name -> {class label: probability}.
            A '_codes' or '_groups' part in the classifier name is removed
            before the weight lookup.

    Returns:
        The label with the highest weighted total if that total is greater
        than 0.5, otherwise ``None``. ``None`` is also returned when there
        are no scores at all.

    Raises:
        KeyError: if a classifier that produced scores has no weight in
            ``CLS_WEIGHTS``.
    """
    total_scores = {}
    for clf_name, scores in preds.items():
        weight_key = clf_name.replace('_codes', '').replace('_groups', '')
        for class_name, score in scores.items():
            weighted = CLS_WEIGHTS[weight_key] * score
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weighted

    # Nothing to rank: no classifiers, or none of them produced a score.
    if not total_scores:
        return None

    # Select the best label directly from the dict. Passing the dict's
    # values() view to np.array yields a 0-d object array whose argmax() is
    # always 0, which would only ever test the first-inserted label.
    best_class = max(total_scores, key=total_scores.get)
    if total_scores[best_class] > 0.5:
        return best_class
    return None
# Flask application object; the route decorators below register their
# handlers on it.
app = Flask(__name__)
@app.get("/")
def test_get():
    """Answer GET / with a fixed payload."""
    payload = {'hello': 'world'}
    return payload
@app.route("/test", methods=['POST'])
def test():
    """Echo the JSON body of a POST /test request under the 'response' key."""
    return {'response': request.json}
@app.route("/predict", methods=['POST'])
def predict_api():
    """Predict a code and a group for a base64-encoded text.

    Expects a JSON body with:
        textB64: the input text, base64-encoded (UTF-8 once decoded).
        top_n: number of top classes to report per classifier (>= 1).

    Returns:
        A dict with the entries "icd10" and "dx_group". Each holds the
        weighted consensus in 'result' (``None`` when get_top_result finds no
        class above its threshold), the per-classifier scores in 'details'
        and a 'message'. Invalid requests get ``{'error': <reason>}``
        instead, returned like the other error payloads of this endpoint.
    """
    data = request.json

    # Missing fields, a non-JSON-object body or a non-numeric top_n are client
    # errors: answer with an error payload rather than an unhandled exception.
    try:
        encoded_text = str(data['textB64']).encode("ascii")
        top_n = int(data['top_n'])
    except (KeyError, TypeError, ValueError):
        return {'error': "request must contain 'textB64' (base64 string) and an integer 'top_n'"}

    # Decode as UTF-8, which is a superset of ASCII, so that texts with
    # non-ASCII characters are accepted. binascii.Error and
    # UnicodeDecodeError are both ValueError subclasses.
    try:
        text = base64.b64decode(encoded_text).decode("utf-8")
    except ValueError:
        return {'error': 'textB64 is not valid base64-encoded UTF-8 text'}

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    pred_codes_top = get_top_result(pred_codes)
    pred_groups_top = get_top_result(pred_groups)

    message_codes = 'models agree' if pred_codes_top is not None else 'models disagree'
    message_group = 'models agree' if pred_groups_top is not None else 'models disagree'

    return {
        "icd10": {
            'result': pred_codes_top,
            'details': pred_codes,
            'message': message_codes,
        },
        "dx_group": {
            'result': pred_groups_top,
            'details': pred_groups,
            'message': message_group,
        },
    }
@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict a code for a base64-encoded text within a given group.

    Expects a JSON body with:
        textB64: the input text, base64-encoded (UTF-8 once decoded).
        top_n: number of top classes to report (>= 1).
        dx_group: name of the group whose classifier should be used.

    Returns:
        ``{"icd10": {...}}`` with the best class in 'result', its score in
        'probability', the top_n scores in 'details' and every class known to
        the group's model in 'all_codes'. Invalid requests get
        ``{'error': <reason>}`` instead, returned like the other error
        payloads of this endpoint.
    """
    data = request.json

    # Missing fields, a non-JSON-object body or a non-numeric top_n are client
    # errors: answer with an error payload rather than an unhandled exception.
    try:
        encoded_text = str(data['textB64']).encode("ascii")
        top_n = int(data['top_n'])
        group_name = data['dx_group']
    except (KeyError, TypeError, ValueError):
        return {'error': "request must contain 'textB64' (base64 string), an integer 'top_n' and 'dx_group'"}

    # Decode as UTF-8, which is a superset of ASCII, so that texts with
    # non-ASCII characters are accepted. binascii.Error and
    # UnicodeDecodeError are both ValueError subclasses.
    try:
        text = base64.b64decode(encoded_text).decode("utf-8")
    except ValueError:
        return {'error': 'textB64 is not valid base64-encoded UTF-8 text'}

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}
    if group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(text, group_name, top_n)

    # The keys of pred_codes are str(class), so look the top class up by its
    # string form. Converting to plain str also keeps the response JSON
    # serializable when the model exposes its classes as NumPy objects
    # (presumably the case for the pickled models — verify).
    top_code = str(top_pred_code)
    return {
        "icd10": {
            'result': top_code,
            'probability': pred_codes[top_code],
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
# Start the Flask server only when this file is executed as a script (not
# when it is imported); it listens on host 0.0.0.0, port 7860.
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)