Spaces:
Runtime error
Runtime error
print('INFO: import modules') | |
import base64 | |
from flask import Flask, request | |
import json | |
import pickle | |
import numpy as np | |
import os | |
from required_classes import BertEmbedder, PredictModel | |
# Ensemble weight per classifier family, consumed by get_top_result().
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}


def _load_classifiers(directory):
    """Unpickle every non-hidden file in *directory* into a dict.

    Keys are file names up to their first '.', values are the unpickled
    objects. Files are processed in sorted order so the resulting dict (and
    everything iterating over it) is deterministic. A file that fails to load
    is reported and skipped so the remaining classifiers still load; an
    unreadable or missing directory yields an empty dict.
    """
    classifiers = {}
    try:
        names = sorted(os.listdir(directory))
    except OSError as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        return classifiers
    for clf_name in names:
        if clf_name.startswith('.'):
            continue  # hidden files (e.g. .gitkeep) are not models
        try:
            # pickle executes arbitrary code on load: only ship trusted
            # model artifacts in this directory.
            with open(os.path.join(directory, clf_name), 'rb') as f:
                classifiers[clf_name.split('.')[0]] = pickle.load(f)
        except Exception as e:
            print(f"ERROR: loading classifier {clf_name} failed with: {str(e)}")
            continue
        print(f'INFO: classifier {clf_name} loaded')
    return classifiers


print('INFO: loading models')
try:
    with open('embedder/embedder.pkl', 'rb') as f:
        embedder = pickle.load(f)
    print('INFO: embedder loaded')
except Exception as e:
    # Best-effort: startup continues, but `embedder` stays undefined, so
    # classification calls that use it will fail until the artifact loads.
    print(f"ERROR: loading embedder failed with: {str(e)}")

classifiers_codes = _load_classifiers('classifiers/codes')
classifiers_groups = _load_classifiers('classifiers/groups')
def _top_n_predictions(text, top_n, classifiers):
    """Score *text* with every classifier and keep each one's best classes.

    Args:
        text: input string, embedded once with the module-level ``embedder``.
        top_n: number of highest-probability classes to keep per classifier;
            values <= 0 keep none.
        classifiers: mapping classifier name -> model exposing
            ``predict_proba`` and ``classes_``.

    Returns:
        Mapping classifier name -> {str(class label): probability}, ordered
        from most to least probable.
    """
    # predict_proba expects a batch, so wrap the single embedding in a list.
    embed = [embedder(text)]
    # Clamp so a non-positive top_n keeps nothing (a negative stop would
    # otherwise slice from the end of the ranking).
    keep = max(top_n, 0)
    preds = {}
    for clf_name, model in classifiers.items():
        probs = model.predict_proba(embed)[0]
        best_n = np.argsort(probs)[::-1][:keep]  # descending probability
        preds[clf_name] = {str(model.classes_[i]): float(probs[i]) for i in best_n}
    return preds


def classify_code(text, top_n):
    """Return the *top_n* ICD-10 code predictions of each code classifier."""
    return _top_n_predictions(text, top_n, classifiers_codes)


def classify_group(text, top_n):
    """Return the *top_n* diagnosis-group predictions of each group classifier."""
    return _top_n_predictions(text, top_n, classifiers_groups)
def get_top_result(preds, weights=None, threshold=0.5):
    """Combine per-classifier scores into one weighted ensemble winner.

    Args:
        preds: mapping classifier name -> {class label: probability}, as
            returned by classify_code / classify_group. Any '_codes' or
            '_groups' part of a classifier name is removed before it is
            looked up in *weights*.
        weights: mapping classifier family -> weight; defaults to CLS_WEIGHTS.
        threshold: the winner's weighted score must be strictly greater than
            this value.

    Returns:
        The class label with the highest weighted score when that score
        exceeds *threshold*, otherwise None (also for empty *preds*).

    Raises:
        KeyError: if a classifier with scores has no entry in *weights*.
    """
    if weights is None:
        weights = CLS_WEIGHTS
    total_scores = {}
    for clf_name, scores in preds.items():
        family = clf_name.replace('_codes', '').replace('_groups', '')
        for class_name, score in scores.items():
            total_scores[class_name] = (
                total_scores.get(class_name, 0) + weights[family] * score
            )
    if not total_scores:
        return None
    # Do not use np.array(dict.values()): that builds a 0-d object array whose
    # argmax is always 0. max() returns the first key on ties, like argmax.
    best = max(total_scores, key=total_scores.get)
    return best if total_scores[best] > threshold else None
# Flask application object; app.run() in the __main__ guard serves it.
app = Flask(import_name=__name__)
def test_get():
    """Liveness view: always answer with a fixed greeting payload."""
    payload = dict(hello='world')
    return payload
def test():
    """Echo view: hand the parsed JSON request body back under 'response'."""
    echoed = request.json
    return dict(response=echoed)
def read_root():
    """Classify base64-encoded text into ICD-10 codes and diagnosis groups.

    Expects a JSON object body with:
      * ``textB64`` -- the input text, base64-encoded from UTF-8 bytes
        (plain ASCII text is a subset and is accepted as well);
      * ``top_n`` -- how many top-scoring classes to report per classifier.

    Returns:
        ``{"icd10": {...}, "dx_group": {...}}``. Each entry holds the
        weighted-ensemble winner under ``result`` (None when no class clears
        the threshold) and the per-classifier top-n scores under
        ``details``. Invalid input yields ``{'error': <message>}`` instead.
    """
    # NOTE(review): no @app.route decorator is visible for this view in this
    # file; confirm it is registered with `app`.
    data = request.json
    if not isinstance(data, dict) or 'textB64' not in data or 'top_n' not in data:
        return {'error': 'request body must be a JSON object with textB64 and top_n'}
    try:
        raw = base64.b64decode(str(data['textB64']).encode("ascii"))
        # UTF-8, not ASCII, so non-English clinical text is accepted.
        text = raw.decode("utf-8")
    except ValueError:
        # binascii.Error and Unicode{En,De}codeError all subclass ValueError.
        return {'error': 'textB64 must be base64-encoded UTF-8 text'}
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError, OverflowError):
        return {'error': 'top_n should be an integer'}
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}
    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    return {
        "icd10": {'result': get_top_result(pred_codes), 'details': pred_codes},
        "dx_group": {'result': get_top_result(pred_groups), 'details': pred_groups},
    }
if __name__ == "__main__":
    # Run as a script: start Flask's built-in server on all interfaces at port 7860.
    app.run(port=7860, host='0.0.0.0')