# Start-up section of the service: imports, loading of the pickled embedder and
# classifiers, and the classification helpers used by the HTTP routes.
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os
# NOTE(review): the star import is presumably required so that pickle can
# resolve the project classes referenced by the pickled objects — confirm.
from required_classes import *


# Per-classifier weights used by get_top_result(). Keys are classifier names
# after the '_codes' / '_groups' suffix has been stripped.
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}


def _iter_pickled_models(directory):
    """Yield ``(file_name, model)`` for every non-hidden file in *directory*.

    Files whose name starts with '.' are skipped. Entries are visited in
    ``os.listdir`` order.

    SECURITY: ``pickle.load`` executes arbitrary code contained in the file.
    Only point this at model files that come from a trusted source.
    """
    for file_name in os.listdir(directory):
        if file_name.startswith('.'):
            continue
        with open(os.path.join(directory, file_name), 'rb') as f:
            yield file_name, pickle.load(f)


print('INFO: loading models')

try:
    # SECURITY: unpickling runs arbitrary code; the file must be trusted.
    with open('embedder/embedder.pkl', 'rb') as f:
        embedder = pickle.load(f)
    # Replace the wrapped embedder with its `.bert` attribute.
    embedder.embedder = embedder.embedder.bert
    print('INFO: embedder loaded')
except Exception as e:
    print(f"ERROR: loading embedder failed with: {str(e)}")
    raise

print('Loading classifiers of codes')
# Maps the file name up to its first '.' -> classifier of codes.
classifiers_codes = {}
try:
    for clf_name, model in _iter_pickled_models('classifiers/codes'):
        classifiers_codes[clf_name.split('.')[0]] = model
        print(f'INFO: codes classifier {clf_name} loaded')
except Exception as e:
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    raise

print('Loading classifiers of groups')
# Maps the file name up to its first '.' -> classifier of groups.
classifiers_groups = {}
try:
    for clf_name, model in _iter_pickled_models('classifiers/groups'):
        classifiers_groups[clf_name.split('.')[0]] = model
        print(f'INFO: groups classifier {clf_name} loaded')
except Exception as e:
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    raise

print('Loading classifiers in groups')
# Maps group name (file name without the '_code_clf.pkl' suffix) -> the
# classifier of codes inside that group.
groups_models = {}
try:
    for clf_name, model in _iter_pickled_models('classifiers/codes_in_groups'):
        group_name = clf_name.replace('_code_clf.pkl', '')
        groups_models[group_name] = model
        print(f'INFO: codes classifier for group {group_name} loaded')
except Exception as e:
    print(f"ERROR: loading classifiers failed with: {str(e)}")
    raise


def _top_n_predictions(model, embed, top_n):
    """Return the *top_n* most probable classes of *model* for *embed*.

    Args:
        model: object exposing ``predict_proba`` and ``classes_``.
        embed: single-sample batch passed straight to ``predict_proba``.
        top_n: number of classes to keep. Callers are expected to pass a
            value >= 1; if it exceeds the number of classes, all classes
            are returned.

    Returns:
        ``(best_indices, predictions)`` where ``best_indices`` holds the
        selected class indices ordered from most to least probable and
        ``predictions`` maps ``str(class label)`` to its probability as a
        plain ``float``, in the same order.
    """
    probs = model.predict_proba(embed)
    # argsort is ascending, so take the last top_n indices and reverse them.
    best_indices = np.flip(np.argsort(probs, axis=1)[0, -top_n:])
    predictions = {
        str(model.classes_[i]): float(probs[0][i]) for i in best_indices
    }
    return best_indices, predictions


def classify_code(text, top_n):
    """Run every classifier of codes on *text*.

    Returns:
        dict mapping classifier name to a ``{class label: probability}`` dict
        holding that classifier's *top_n* predictions, best first.
    """
    embed = [embedder(text)]
    return {
        clf_name: _top_n_predictions(model, embed, top_n)[1]
        for clf_name, model in classifiers_codes.items()
    }


def classify_group(text, top_n):
    """Run every classifier of groups on *text*.

    Returns:
        dict mapping classifier name to a ``{class label: probability}`` dict
        holding that classifier's *top_n* predictions, best first.
    """
    embed = [embedder(text)]
    return {
        clf_name: _top_n_predictions(model, embed, top_n)[1]
        for clf_name, model in classifiers_groups.items()
    }


def classify_code_by_group(text, group_name, top_n):
    """Classify *text* with the classifier dedicated to *group_name*.

    Raises:
        KeyError: if no classifier was loaded for *group_name*.

    Returns:
        ``(top_cls, top_n_preds, all_codes_in_group)``: the most probable
        class label exactly as stored in ``model.classes_``, the
        ``{str(label): probability}`` dict of the *top_n* predictions, and
        ``model.classes_`` itself.
    """
    embed = [embedder(text)]
    model = groups_models[group_name]
    best_indices, top_n_preds = _top_n_predictions(model, embed, top_n)
    top_cls = model.classes_[best_indices[0]]
    return top_cls, top_n_preds, model.classes_


def get_top_result(preds):
    """Combine per-classifier predictions into a single weighted decision.

    Each class accumulates ``CLS_WEIGHTS[classifier] * probability`` over all
    classifiers that listed it. The classifier name is looked up after
    removing the '_codes' / '_groups' substrings.

    Args:
        preds: ``{classifier name: {class label: probability}}`` as returned
            by ``classify_code`` / ``classify_group``.

    Raises:
        KeyError: if a classifier name has no entry in ``CLS_WEIGHTS``.

    Returns:
        The class label with the highest total score when that score is
        strictly greater than 0.5, otherwise ``None`` (also for empty input).
    """
    total_scores = {}
    for clf_name, scores in preds.items():
        weight = CLS_WEIGHTS[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weight * score

    if not total_scores:
        return None

    # np.array(dict.values()) builds a 0-d object array whose argmax() is
    # always 0, so pick the best key directly from the dict instead.
    best_class = max(total_scores, key=total_scores.get)
    if total_scores[best_class] > 0.5:
        return best_class
    return None


app = Flask(__name__)
@app.get("/")
def test_get():
    """Return a fixed payload; usable as a simple liveness check."""
    return {'hello': 'world'}


@app.route("/test", methods=['POST'])
def test():
    """Echo the posted JSON body back under the 'response' key."""
    data = request.json
    return {'response': data}


def _bad_request(message):
    """Build an HTTP 400 response carrying *message* under the 'error' key."""
    return {'error': message}, 400


def _parse_text_request(data):
    """Extract and validate the ``textB64`` and ``top_n`` request fields.

    Args:
        data: the parsed JSON body of the request.

    Returns:
        ``(text, top_n, error)``. On success ``error`` is ``None``. On
        failure ``text`` and ``top_n`` are ``None`` and ``error`` is a value
        the view can return directly.

    Malformed input that used to raise an unhandled exception (body not an
    object, missing fields, invalid base64 / text encoding, non-integer
    ``top_n``) yields HTTP 400. The two validation errors that already
    existed keep their original payload and default status so existing
    clients see no change.
    """
    if not isinstance(data, dict):
        return None, None, _bad_request('request body should be a JSON object')

    for field in ('textB64', 'top_n'):
        if field not in data:
            return None, None, _bad_request(f'missing required field: {field}')

    try:
        base64_bytes = str(data['textB64']).encode("ascii")
        # UTF-8 is a superset of ASCII, so previously valid requests decode
        # identically while non-ASCII text no longer crashes the request.
        text = base64.b64decode(base64_bytes).decode("utf-8")
    except ValueError:
        # Covers binascii.Error, UnicodeEncodeError and UnicodeDecodeError.
        return None, None, _bad_request(
            'textB64 should be valid base64-encoded UTF-8 text')

    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError, OverflowError):
        return None, None, _bad_request('top_n should be an integer')

    if top_n < 1:
        return None, None, {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return None, None, {'error': 'text is empty'}
    return text, top_n, None


@app.route("/predict", methods=['POST'])
def predict_api():
    """Predict both the code and the group for the posted text.

    Expects a JSON object with ``textB64`` (base64-encoded text) and
    ``top_n`` (number of predictions to keep per classifier, >= 1).

    Returns a JSON object with 'icd10' and 'dx_group' entries, each holding
    the combined 'result' (``None`` when the weighted score does not exceed
    the threshold used by ``get_top_result``), the per-classifier 'details'
    and a 'message' stating whether the models agree.
    """
    text, top_n, error = _parse_text_request(request.json)
    if error is not None:
        return error

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    pred_codes_top = get_top_result(pred_codes)
    pred_groups_top = get_top_result(pred_groups)

    message_codes = 'models agree' if pred_codes_top is not None else 'models disagree'
    message_group = 'models agree' if pred_groups_top is not None else 'models disagree'

    result = {
        "icd10": {
            'result': pred_codes_top,
            'details': pred_codes,
            'message': message_codes,
        },
        "dx_group": {
            'result': pred_groups_top,
            'details': pred_groups,
            'message': message_group,
        },
    }
    return result


@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict the code for the posted text within a given group.

    Expects a JSON object with ``textB64`` (base64-encoded text), ``top_n``
    (number of predictions to keep, >= 1) and ``dx_group`` (name of a group
    that has a dedicated classifier in ``groups_models``).

    Returns a JSON object whose 'icd10' entry holds the best code, its
    probability, the top-n 'details' and the list of all codes known to the
    group's classifier.
    """
    data = request.json
    text, top_n, error = _parse_text_request(data)
    if error is not None:
        return error

    if 'dx_group' not in data:
        return _bad_request('missing required field: dx_group')
    group_name = data['dx_group']
    # A non-string value (e.g. a list) would be unhashable and can never
    # name a classifier, so treat it like an unknown group.
    if not isinstance(group_name, str) or group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(
        text, group_name, top_n)

    # Class labels come straight from the model and may be NumPy scalars or
    # arrays, which Flask's JSON encoder rejects. `pred_codes` is keyed by
    # str(label), so normalise to plain strings before indexing and returning.
    top_code = str(top_pred_code)
    result = {
        "icd10": {
            'result': top_code,
            'probability': pred_codes[top_code],
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
    return result


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)