Spaces:
Runtime error
Runtime error
File size: 3,858 Bytes
6304a81 d5e3dc4 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 4a7adf1 6304a81 3a4727b 4a7adf1 afc3da6 d5e3dc4 6304a81 da7535b 6304a81 da7535b 6304a81 da7535b 6304a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os
from required_classes import BertEmbedder, PredictModel
# Per-classifier weights used by get_top_result() to blend the class
# probabilities of the individual models into one score per class.
# The three weights sum to 1.0. Keys are looked up by classifier name
# after any '_codes' / '_groups' part has been removed from it.
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}
print('INFO: loading models')
# Restore the pickled text embedder; it is later invoked as `embedder(text)`.
# A failure is only reported here, in which case `embedder` stays undefined
# and the classify_* functions will fail when they are called.
# NOTE(review): unpickling presumably needs the classes imported from
# required_classes to be importable - confirm.
try:
    with open('embedder/embedder.pkl', 'rb') as embedder_file:
        embedder = pickle.load(embedder_file)
except Exception as e:
    print(f"ERROR: loading embedder failed with: {str(e)}")
else:
    print('INFO: embedder loaded')
# Registry of code-level classifiers: file name up to the first '.' -> model.
classifiers_codes = {}
try:
    for clf_name in os.listdir('classifiers/codes'):
        # Hidden entries (names starting with a dot) are not model files.
        if clf_name.startswith('.'):
            continue
        registry_key = clf_name.split('.')[0]
        with open('classifiers/codes/' + clf_name, 'rb') as model_file:
            classifiers_codes[registry_key] = pickle.load(model_file)
        print(f'INFO: classifier {clf_name} loaded')
except Exception as e:
    # NOTE(review): the try wraps the whole loop, so the first file that
    # fails to load also prevents the remaining files from being loaded.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
# Registry of group-level classifiers: file name up to the first '.' -> model.
classifiers_groups = {}
try:
    for clf_name in os.listdir('classifiers/groups'):
        # Hidden entries (names starting with a dot) are not model files.
        if clf_name.startswith('.'):
            continue
        registry_key = clf_name.split('.')[0]
        with open('classifiers/groups/' + clf_name, 'rb') as model_file:
            classifiers_groups[registry_key] = pickle.load(model_file)
        print(f'INFO: classifier {clf_name} loaded')
except Exception as e:
    # NOTE(review): the try wraps the whole loop, so the first file that
    # fails to load also prevents the remaining files from being loaded.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
def classify_code(text, top_n):
    """Score `text` with every loaded code classifier.

    Args:
        text: Input handed to the module-level `embedder`.
        top_n: Number of highest-probability classes to keep per classifier.
            Expected to be >= 1; with 0 the slice `-0:` would keep every class.

    Returns:
        dict mapping classifier name -> {class label (str): probability
        (float)}, with the labels ordered from most to least probable.
    """
    # A single sample is scored, wrapped in a list to form a one-row batch.
    features = [embedder(text)]
    results = {}
    for name, clf in classifiers_codes.items():
        proba = clf.predict_proba(features)
        # argsort is ascending: the last `top_n` indices of the first row are
        # the best ones; reverse them so the most probable class comes first.
        top_idx = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[i]): float(proba[0][i]) for i in top_idx
        }
    return results
def classify_group(text, top_n):
    """Score `text` with every loaded group classifier.

    Args:
        text: Input handed to the module-level `embedder`.
        top_n: Number of highest-probability classes to keep per classifier.
            Expected to be >= 1; with 0 the slice `-0:` would keep every class.

    Returns:
        dict mapping classifier name -> {class label (str): probability
        (float)}, with the labels ordered from most to least probable.
    """
    # A single sample is scored, wrapped in a list to form a one-row batch.
    features = [embedder(text)]
    results = {}
    for name, clf in classifiers_groups.items():
        proba = clf.predict_proba(features)
        # argsort is ascending: the last `top_n` indices of the first row are
        # the best ones; reverse them so the most probable class comes first.
        top_idx = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[i]): float(proba[0][i]) for i in top_idx
        }
    return results
def get_top_result(preds):
    """Combine per-classifier predictions into a single winning class.

    Every classifier's probabilities are multiplied by that classifier's
    weight from CLS_WEIGHTS and summed per class label.

    Args:
        preds: dict mapping classifier name -> {class label: probability},
            as produced by classify_code() / classify_group(). A '_codes' or
            '_groups' part of the classifier name is removed before the
            weight lookup.

    Returns:
        The class label with the highest weighted score if that score is
        strictly greater than 0.5, otherwise None (also for empty input).

    Raises:
        KeyError: if a classifier name has no entry in CLS_WEIGHTS.
    """
    total_scores = {}
    for clf_name, scores in preds.items():
        weight_key = clf_name.replace('_codes', '').replace('_groups', '')
        for class_name, score in scores.items():
            weighted = CLS_WEIGHTS[weight_key] * score
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weighted

    # Nothing was scored, so there is no winner to report.
    if not total_scores:
        return None

    # Pick the best label directly from the dict. Wrapping `dict.values()` in
    # np.array() yields a 0-d object array whose argmax() is always 0, which
    # would select the first-seen class instead of the highest-scoring one.
    best_class = max(total_scores, key=total_scores.get)
    if total_scores[best_class] > 0.5:
        return best_class
    return None
# Flask application object; the route decorators below register their views on it.
app = Flask(__name__)
@app.get("/")
def test_get():
    """Answer GET / with a fixed JSON payload (simple liveness check)."""
    payload = {'hello': 'world'}
    return payload
@app.route("/test", methods=['POST'])
def test():
    """Echo the JSON body of a POST /test request back under 'response'."""
    return {'response': request.json}
@app.route("/predict", methods=['POST'])
def read_root():
    """Classify a base64-encoded text and return code and group predictions.

    Expects a JSON object with:
        textB64: the input text, base64-encoded (decoded as UTF-8).
        top_n:   number of top classes to report per classifier (>= 1).

    Returns:
        On success, a dict with the keys "icd10" and "dx_group". Each holds
        'result' (the combined top class or None, from get_top_result) and
        'details' (the per-classifier predictions).
        On a malformed request (missing field, invalid base64/encoding,
        non-integer top_n), an {'error': ...} dict with HTTP status 400.
        For top_n < 1 or blank text, an {'error': ...} dict with the default
        status, unchanged from the previous behaviour.
    """
    data = request.json
    # Malformed requests used to surface as unhandled exceptions (HTTP 500);
    # answer them with an explicit client error instead.
    if not isinstance(data, dict) or 'textB64' not in data or 'top_n' not in data:
        return {'error': "request body must be a JSON object with 'textB64' and 'top_n'"}, 400

    try:
        base64_bytes = str(data['textB64']).encode("ascii")
        sample_string_bytes = base64.b64decode(base64_bytes)
        # UTF-8 is a superset of ASCII: existing ASCII payloads decode the
        # same way, and non-ASCII text no longer raises UnicodeDecodeError.
        text = sample_string_bytes.decode("utf-8")
    except ValueError:
        # binascii.Error, UnicodeEncodeError and UnicodeDecodeError are all
        # subclasses of ValueError.
        return {'error': 'textB64 is not valid base64-encoded UTF-8 text'}, 400

    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}, 400

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    pred_codes_top = get_top_result(pred_codes)
    pred_groups_top = get_top_result(pred_groups)
    result = {
        "icd10":
            {'result': pred_codes_top, 'details': pred_codes},
        "dx_group":
            {'result': pred_groups_top, 'details': pred_groups}
    }
    return result
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860, but only
    # when this file is executed directly rather than imported.
    app.run(port=7860, host='0.0.0.0')
|