# NOTE: page-capture residue, not part of the program: "Spaces: Runtime error"
print('INFO: import modules') | |
import base64 | |
from flask import Flask, request | |
import json | |
import pickle | |
import numpy as np | |
import os | |
from required_classes import * | |
# Per-classifier weights used by get_top_result() when it sums the scores of
# the individual classifiers into one score per class.
CLS_WEIGHTS = {
    'mlp': 0.3,
    'svc': 0.4,
    'xgboost': 0.3,
}
# Load the pickled text embedder used by every classify_* function.
# A failure is only reported on stdout; `embedder` then stays undefined and
# the first request that needs it will fail.
# NOTE(review): pickle.load can execute arbitrary code - the file must come
# from a trusted source.
print('INFO: loading models')
try:
    with open('embedder/embedder.pkl', 'rb') as f:
        embedder = pickle.load(f)
except Exception as e:
    print(f"ERROR: loading embedder failed with: {str(e)}")
else:
    print('INFO: embedder loaded')
# Load every pickled code classifier found in classifiers/codes into
# `classifiers_codes`, keyed by the file name up to its first dot.
# Each file is loaded under its own try/except so that one unreadable or
# corrupt pickle no longer prevents the remaining classifiers from loading.
# NOTE(review): pickle.load can execute arbitrary code - these files must
# come from a trusted source.
print('Loading classifiers of codes')
classifiers_codes = {}
try:
    _codes_clf_files = os.listdir('classifiers/codes')
except OSError as e:
    # Directory missing or unreadable: report it and serve without models.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    _codes_clf_files = []
for clf_name in _codes_clf_files:
    # Skip hidden files such as .gitkeep / .DS_Store.
    if clf_name.startswith('.'):
        continue
    try:
        with open('classifiers/codes/' + clf_name, 'rb') as f:
            model = pickle.load(f)
    except Exception as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        continue
    classifiers_codes[clf_name.split('.')[0]] = model
    print(f'INFO: codes classifier {clf_name} loaded')
# Load every pickled group classifier found in classifiers/groups into
# `classifiers_groups`, keyed by the file name up to its first dot.
# Each file is loaded under its own try/except so that one unreadable or
# corrupt pickle no longer prevents the remaining classifiers from loading.
# NOTE(review): pickle.load can execute arbitrary code - these files must
# come from a trusted source.
print('Loading classifiers of groups')
classifiers_groups = {}
try:
    _groups_clf_files = os.listdir('classifiers/groups')
except OSError as e:
    # Directory missing or unreadable: report it and serve without models.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    _groups_clf_files = []
for clf_name in _groups_clf_files:
    # Skip hidden files such as .gitkeep / .DS_Store.
    if clf_name.startswith('.'):
        continue
    try:
        with open('classifiers/groups/' + clf_name, 'rb') as f:
            model = pickle.load(f)
    except Exception as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        continue
    classifiers_groups[clf_name.split('.')[0]] = model
    print(f'INFO: groups classifier {clf_name} loaded')
# Load the per-group code classifiers found in classifiers/codes_in_groups
# into `groups_models`. The key is the file name with the
# '_code_clf.pkl' suffix removed, i.e. the group name.
# Each file is loaded under its own try/except so that one unreadable or
# corrupt pickle no longer prevents the remaining classifiers from loading.
# NOTE(review): pickle.load can execute arbitrary code - these files must
# come from a trusted source.
print('Loading classifiers in groups')
groups_models = {}
try:
    _in_group_clf_files = os.listdir('classifiers/codes_in_groups')
except OSError as e:
    # Directory missing or unreadable: report it and serve without models.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    _in_group_clf_files = []
for clf_name in _in_group_clf_files:
    # Skip hidden files such as .gitkeep / .DS_Store.
    if clf_name.startswith('.'):
        continue
    try:
        with open('classifiers/codes_in_groups/' + clf_name, 'rb') as f:
            model = pickle.load(f)
    except Exception as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        continue
    group_name = clf_name.replace('_code_clf.pkl', '')
    groups_models[group_name] = model
    print(f'INFO: codes classifier for group {group_name} loaded')
def classify_code(text, top_n):
    """Score `text` with every loaded code classifier.

    The text is embedded once with the module-level `embedder`, then each
    model in `classifiers_codes` is asked for class probabilities.

    Args:
        text: Input text to classify.
        top_n: Number of highest-probability classes to keep per
            classifier. Expected to be a positive integer.

    Returns:
        dict mapping classifier name to a dict of
        ``{str(class label): float probability}`` for that classifier's
        ``top_n`` classes, ordered from most to least probable.
    """
    features = [embedder(text)]  # single-sample batch
    results = {}
    for name, clf in classifiers_codes.items():
        proba = clf.predict_proba(features)
        # argsort is ascending: take the last top_n indices and reverse them.
        ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[idx]): float(proba[0][idx]) for idx in ranked
        }
    return results
def classify_group(text, top_n):
    """Score `text` with every loaded group classifier.

    The text is embedded once with the module-level `embedder`, then each
    model in `classifiers_groups` is asked for class probabilities.

    Args:
        text: Input text to classify.
        top_n: Number of highest-probability classes to keep per
            classifier. Expected to be a positive integer.

    Returns:
        dict mapping classifier name to a dict of
        ``{str(class label): float probability}`` for that classifier's
        ``top_n`` classes, ordered from most to least probable.
    """
    features = [embedder(text)]  # single-sample batch
    results = {}
    for name, clf in classifiers_groups.items():
        proba = clf.predict_proba(features)
        # argsort is ascending: take the last top_n indices and reverse them.
        ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
        results[name] = {
            str(clf.classes_[idx]): float(proba[0][idx]) for idx in ranked
        }
    return results
def classify_code_by_group(text, group_name, top_n):
    """Score `text` with the code classifier dedicated to one group.

    Args:
        text: Input text to classify.
        group_name: Key into `groups_models`; a missing key raises KeyError.
        top_n: Number of highest-probability classes to keep. Expected to
            be a positive integer.

    Returns:
        A 3-tuple ``(top_cls, top_n_preds, all_codes_in_group)``:
        the most probable class label exactly as stored in
        ``model.classes_``, a dict of ``{str(label): float probability}``
        for the ``top_n`` classes (most probable first), and the model's
        full ``classes_`` container.
    """
    clf = groups_models[group_name]
    proba = clf.predict_proba([embedder(text)])  # single-sample batch
    # argsort is ascending: take the last top_n indices and reverse them.
    ranked = np.argsort(proba, axis=1)[0, -top_n:][::-1]
    top_n_preds = {}
    for idx in ranked:
        top_n_preds[str(clf.classes_[idx])] = float(proba[0][idx])
    return clf.classes_[ranked[0]], top_n_preds, clf.classes_
def get_top_result(preds, weights=None):
    """Combine per-classifier scores into one weighted winner.

    Each classifier's scores are multiplied by that classifier's weight and
    summed per class. The class with the highest total is returned only if
    that total is strictly greater than 0.5.

    Args:
        preds: dict mapping classifier name to ``{class name: score}``.
            A ``'_codes'`` or ``'_groups'`` part of the classifier name is
            stripped before the weight lookup.
        weights: optional dict mapping the stripped classifier name to its
            weight. Defaults to the module-level ``CLS_WEIGHTS``.

    Returns:
        The winning class name, or ``None`` when no class exceeds 0.5 or
        when ``preds`` contains no scores at all.

    Raises:
        KeyError: if a classifier name has no entry in ``weights``.
    """
    if weights is None:
        weights = CLS_WEIGHTS

    total_scores = {}
    for clf_name, scores in preds.items():
        weight = weights[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = (
                total_scores.get(class_name, 0.0) + weight * score
            )

    if not total_scores:
        return None

    # The previous np.array(dict.values()).argmax() wrapped the dict view in
    # a 0-d object array, so it always yielded index 0 (the first class
    # inserted) rather than the highest-scoring class.
    best_class = max(total_scores, key=total_scores.get)
    return best_class if total_scores[best_class] > 0.5 else None
# Flask application object; started by app.run() in the __main__ guard.
app = Flask(__name__)
def test_get():
    """Return a fixed payload, usable as a simple liveness check."""
    # NOTE(review): no route registration for this handler is visible in
    # this file - confirm how it is bound to a URL.
    payload = {'hello': 'world'}
    return payload
def test():
    """Echo the JSON body of the current request under the 'response' key."""
    # NOTE(review): no route registration for this handler is visible in
    # this file - confirm how it is bound to a URL.
    return {'response': request.json}
def predict_api():
    """Classify the request text with all code and group classifiers.

    Expects a JSON body with:
        textB64: the text to classify, base64-encoded (UTF-8 bytes).
        top_n: how many top classes to report per classifier (>= 1).

    Returns:
        On success, a dict with the keys ``"icd10"`` and ``"dx_group"``.
        Each holds ``'result'`` (the weighted winner from get_top_result,
        or None), ``'details'`` (per-classifier top-n scores) and
        ``'message'`` ('models agree' when a winner exists, otherwise
        'models disagree').
        On invalid input, a dict with a single ``'error'`` key.
    """
    # NOTE(review): no route registration for this handler is visible in
    # this file - confirm how it is bound to a URL.
    data = request.json
    try:
        # Decode as UTF-8 (a superset of ASCII) so non-ASCII text no longer
        # crashes the handler. binascii.Error and UnicodeDecodeError are both
        # ValueError subclasses.
        text = base64.b64decode(str(data['textB64'])).decode('utf-8')
        top_n = int(data['top_n'])
    except (KeyError, TypeError, ValueError) as exc:
        # Missing field, non-object body, bad base64/UTF-8 or non-integer
        # top_n: report it like the other validation errors below.
        return {'error': f'invalid request: {exc}'}

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    pred_codes_top = get_top_result(pred_codes)
    pred_groups_top = get_top_result(pred_groups)

    message_codes = 'models agree' if pred_codes_top is not None else 'models disagree'
    message_group = 'models agree' if pred_groups_top is not None else 'models disagree'

    return {
        "icd10": {
            'result': pred_codes_top,
            'details': pred_codes,
            'message': message_codes,
        },
        "dx_group": {
            'result': pred_groups_top,
            'details': pred_groups,
            'message': message_group,
        },
    }
def predict_code_api():
    """Classify the request text with the classifier of one given group.

    Expects a JSON body with:
        textB64: the text to classify, base64-encoded (UTF-8 bytes).
        top_n: how many top classes to report (>= 1).
        dx_group: name of the group; must be a key of `groups_models`.

    Returns:
        On success, a dict with the key ``"icd10"`` holding ``'result'``
        (most probable code), ``'probability'`` (its score), ``'details'``
        (top-n scores) and ``'all_codes'`` (every class of the group's
        model). All labels are returned as plain strings.
        On invalid input, a dict with a single ``'error'`` key.
    """
    # NOTE(review): no route registration for this handler is visible in
    # this file - confirm how it is bound to a URL.
    data = request.json
    try:
        # Decode as UTF-8 (a superset of ASCII) so non-ASCII text no longer
        # crashes the handler. binascii.Error and UnicodeDecodeError are both
        # ValueError subclasses.
        text = base64.b64decode(str(data['textB64'])).decode('utf-8')
        top_n = int(data['top_n'])
        group_name = data['dx_group']
    except (KeyError, TypeError, ValueError) as exc:
        # Missing field, non-object body, bad base64/UTF-8 or non-integer
        # top_n: report it like the other validation errors below.
        return {'error': f'invalid request: {exc}'}

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    if group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(
        text, group_name, top_n
    )
    # `pred_codes` is keyed by str(label), so look the winner up by its
    # string form. Labels come straight from model.classes_ (presumably a
    # NumPy array - confirm), which the JSON encoder cannot serialise, so
    # everything is converted to plain str / list here.
    top_code = str(top_pred_code)
    return {
        "icd10": {
            'result': top_code,
            'probability': pred_codes[top_code],
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
# Start the Flask development server when the file is run as a script.
if __name__ == "__main__":
    # 0.0.0.0 makes the server listen on all network interfaces.
    app.run(port=7860, host='0.0.0.0')