Spaces:
Runtime error
Runtime error
File size: 6,084 Bytes
6304a81 d5e3dc4 6304a81 da7535b 6304a81 e1eb682 6304a81 da7535b 488bb56 da7535b 488bb56 da7535b 488bb56 da7535b 6304a81 da7535b 488bb56 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 488bb56 da7535b 6304a81 4a7adf1 6304a81 488bb56 4a7adf1 afc3da6 d5e3dc4 6304a81 da7535b 08004fd 6304a81 08004fd 6304a81 08004fd 6304a81 488bb56 08004fd 488bb56 6304a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os
from required_classes import *
# Relative weight of each classifier family in the ensemble vote computed by
# get_top_result; keys are classifier names with the "_codes"/"_groups"
# suffix removed.
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}

print('INFO: loading models')
try:
    with open('embedder/embedder.pkl', 'rb') as embedder_file:
        embedder = pickle.load(embedder_file)
except Exception as err:
    # Best-effort: report and keep starting the app; requests will fail later
    # if the embedder could not be loaded.
    print(f"ERROR: loading embedder failed with: {str(err)}")
else:
    print('INFO: embedder loaded')
def _load_pickled_models(directory, make_key, loaded_msg):
    """Unpickle every non-hidden file in *directory* into a dict.

    Args:
        directory: folder holding one pickled model per file.
        make_key: callable mapping a file name to the dict key for its model.
        loaded_msg: message printed after each successful load; may use the
            ``{file}`` (file name) and ``{key}`` (dict key) placeholders.

    Returns:
        ``{make_key(file_name): unpickled_object}`` for every file that loaded.
        A missing or unreadable directory yields an empty dict.  A file that
        fails to unpickle is reported and skipped; it no longer aborts loading
        of the remaining files.

    NOTE: pickle can execute arbitrary code on load; only point this at
    trusted, locally shipped model files.
    """
    models = {}
    try:
        # Sorted so the load order (and log output) is deterministic.
        file_names = sorted(os.listdir(directory))
    except Exception as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        return models
    for file_name in file_names:
        if file_name.startswith('.'):
            continue  # skip hidden files (names starting with '.')
        path = os.path.join(directory, file_name)
        try:
            with open(path, 'rb') as f:
                model = pickle.load(f)
        except Exception as e:
            print(f"ERROR: loading classifier {path} failed with: {str(e)}")
            continue
        key = make_key(file_name)
        models[key] = model
        print(loaded_msg.format(file=file_name, key=key))
    return models


print('Loading classifiers of codes')
classifiers_codes = _load_pickled_models(
    'classifiers/codes',
    lambda name: name.split('.')[0],
    'INFO: codes classifier {file} loaded',
)

print('Loading classifiers of groups')
classifiers_groups = _load_pickled_models(
    'classifiers/groups',
    lambda name: name.split('.')[0],
    'INFO: groups classifier {file} loaded',
)

print('Loading classifiers in groups')
# One per-group code classifier, stored as "<group>_code_clf.pkl".
groups_models = _load_pickled_models(
    'classifiers/codes_in_groups',
    lambda name: name.replace('_code_clf.pkl', ''),
    'INFO: codes classifier for group {key} loaded',
)
def classify_code(text, top_n):
    """Score *text* with every ICD-10 code classifier in ``classifiers_codes``.

    Args:
        text: raw input text; it is embedded once with ``embedder``.
        top_n: number of highest-probability classes to keep per classifier.
            Expected to be >= 1; with 0 the slice ``-0:`` keeps every class.

    Returns:
        ``{classifier_name: {str(class_label): probability}}`` where each inner
        dict lists the kept classes from highest to lowest probability.
    """
    features = [embedder(text)]
    result = {}
    for name, clf in classifiers_codes.items():
        scores = clf.predict_proba(features)
        # argsort is ascending: take the last top_n indices of row 0, reversed.
        ranked = np.argsort(scores, axis=1)[0, -top_n:][::-1]
        result[name] = {str(clf.classes_[idx]): float(scores[0][idx]) for idx in ranked}
    return result
def classify_group(text, top_n):
    """Score *text* with every diagnosis-group classifier in ``classifiers_groups``.

    Args:
        text: raw input text; it is embedded once with ``embedder``.
        top_n: how many of the most probable groups each classifier reports.
            Expected to be >= 1; with 0 the slice ``-0:`` keeps every class.

    Returns:
        Mapping of classifier name to ``{str(group_label): probability}``,
        ordered from most to least probable.
    """
    embedding = [embedder(text)]
    predictions = {}
    for name in classifiers_groups:
        model = classifiers_groups[name]
        proba = model.predict_proba(embedding)
        ascending = np.argsort(proba, axis=1)[0, -top_n:]
        predictions[name] = {
            str(model.classes_[k]): float(proba[0][k]) for k in reversed(ascending)
        }
    return predictions
def classify_code_by_group(text, group_name, top_n):
    """Predict ICD-10 codes for *text* using the classifier of one diagnosis group.

    Args:
        text: raw input text; it is embedded with ``embedder``.
        group_name: key into ``groups_models``; an unknown group raises KeyError.
        top_n: number of highest-probability codes to report (expected >= 1).

    Returns:
        A tuple ``(top_cls, top_n_preds, all_codes_in_group)``:
        ``top_cls`` is the most probable code as ``str``, in the same form as
        the keys of ``top_n_preds``, so ``top_n_preds[top_cls]`` always works;
        ``top_n_preds`` maps ``str(code)`` to its probability, highest first;
        ``all_codes_in_group`` is the model's ``classes_`` attribute, unchanged.
    """
    embed = [embedder(text)]
    model = groups_models[group_name]
    probs = model.predict_proba(embed)
    # Indices of the top_n largest probabilities in row 0, highest first.
    best_n = np.flip(np.argsort(probs, axis=1)[0, -top_n:])
    top_n_preds = {str(model.classes_[i]): float(probs[0][i]) for i in best_n}
    # Stringify like the dict keys above: a raw numpy label (e.g. np.int64)
    # would not match those str keys and is not JSON-serializable.
    top_cls = str(model.classes_[best_n[0]])
    all_codes_in_group = model.classes_
    return top_cls, top_n_preds, all_codes_in_group
def get_top_result(preds, threshold=0.5, weights=None):
    """Combine per-classifier predictions into one weighted-vote winner.

    Args:
        preds: mapping ``classifier_name -> {class_name: probability}`` as
            produced by ``classify_code`` / ``classify_group``.  The classifier
            name with ``_codes`` / ``_groups`` removed must be a key of *weights*.
        threshold: the winning class's weighted score must be strictly greater
            than this value.
        weights: classifier-family weights; defaults to ``CLS_WEIGHTS``.

    Returns:
        The class with the highest weighted score if that score exceeds
        *threshold*, otherwise ``None``.  Empty *preds* also yield ``None``.
        Ties go to the class encountered first.

    Raises:
        KeyError: if a classifier name has no entry in *weights*.
    """
    if weights is None:
        weights = CLS_WEIGHTS
    total_scores = {}
    for clf_name, scores in preds.items():
        weight = weights[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weight * score
    if not total_scores:
        return None
    # Pick the best class straight from the dict.  np.array(dict.values()) would
    # build a 0-d object array whose argmax() is always 0, i.e. the first key.
    best_class = max(total_scores, key=total_scores.get)
    return best_class if total_scores[best_class] > threshold else None
app = Flask(__name__)


@app.route("/", methods=["GET"])
def test_get():
    """Health-check endpoint: answer ``GET /`` with a fixed JSON greeting."""
    greeting = {'hello': 'world'}
    return greeting
@app.route("/test", methods=['POST'])
def test():
    """Echo endpoint: return the posted JSON payload under the ``response`` key."""
    return {'response': request.json}
@app.route("/predict", methods=['POST'])
def predict_api():
    """Predict ICD-10 codes and diagnosis groups for a base64-encoded text.

    Expects a JSON object body with:
        textB64: the input text, base64-encoded (UTF-8; plain ASCII also works).
        top_n:   how many top classes each classifier reports (must be >= 1).

    Returns a JSON object with ``icd10`` and ``dx_group`` entries.  Each entry
    holds:
        result:  the weighted ensemble winner from ``get_top_result``, or None.
        details: the per-classifier top-n probabilities.
        message: a short agreement message.
    Invalid input produces ``{'error': ...}`` with HTTP 200, as before.
    """
    data = request.json
    if not isinstance(data, dict) or 'textB64' not in data or 'top_n' not in data:
        return {'error': 'request body must be a JSON object with textB64 and top_n'}
    try:
        # UTF-8 is a superset of ASCII: ASCII payloads decode exactly as before,
        # while non-ASCII text no longer crashes the request.
        text = base64.b64decode(str(data['textB64']).encode("ascii")).decode("utf-8")
    except ValueError:  # binascii.Error and UnicodeError both subclass ValueError
        return {'error': 'textB64 is not valid base64-encoded UTF-8 text'}
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError, OverflowError):
        return {'error': 'top_n should be an integer'}
    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    pred_codes_top = get_top_result(pred_codes)
    pred_groups_top = get_top_result(pred_groups)
    # get_top_result returns None when no class clears its score threshold.
    message_codes = 'models agree' if pred_codes_top is not None else 'models disagree'
    message_group = 'models agree' if pred_groups_top is not None else 'models disagree'
    result = {
        "icd10":
            {'result': pred_codes_top, 'details': pred_codes, 'message': message_codes},
        "dx_group":
            {'result': pred_groups_top, 'details': pred_groups, 'message': message_group},
    }
    return result
@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict ICD-10 codes within one diagnosis group for a base64-encoded text.

    Expects a JSON object body with:
        textB64:  the input text, base64-encoded (UTF-8; plain ASCII also works).
        top_n:    how many top codes to report (must be >= 1).
        dx_group: name of a group that has a classifier in ``groups_models``.

    Returns ``{'icd10': {...}}`` with these fields:
        result:      the best code.
        probability: that code's probability.
        details:     the top-n code probabilities.
        all_codes:   every code the group's classifier knows, as strings.
    Invalid input produces ``{'error': ...}`` with HTTP 200, as before.
    """
    data = request.json
    if not isinstance(data, dict) or not all(k in data for k in ('textB64', 'top_n', 'dx_group')):
        return {'error': 'request body must be a JSON object with textB64, top_n and dx_group'}
    try:
        # UTF-8 is a superset of ASCII: ASCII payloads decode exactly as before,
        # while non-ASCII text no longer crashes the request.
        text = base64.b64decode(str(data['textB64']).encode("ascii")).decode("utf-8")
    except ValueError:  # binascii.Error and UnicodeError both subclass ValueError
        return {'error': 'textB64 is not valid base64-encoded UTF-8 text'}
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError, OverflowError):
        return {'error': 'top_n should be an integer'}
    group_name = data['dx_group']
    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    # Non-string JSON values (lists/objects) are unhashable and cannot name a group.
    if not isinstance(group_name, str) or group_name not in groups_models:
        return {'error': 'have no classifier for the group'}
    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(text, group_name, top_n)
    # pred_codes is keyed by str(label). Normalize the winner and the label array
    # the same way so the lookup succeeds and the response is JSON-serializable;
    # a numpy array or numpy integer label makes Flask's JSON encoder raise.
    top_code = str(top_pred_code)
    result = {
        "icd10":
        {
            'result': top_code,
            'probability': pred_codes[top_code],
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
    return result
if __name__ == "__main__":
    # Only when executed directly: serve on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')
|