"""Flask service that classifies free text into ICD-10 codes and diagnosis groups.

At import time the module loads a pickled text embedder and two sets of pickled
classifiers from disk.  The ``/predict`` endpoint embeds the submitted text,
asks every classifier for its top-N class probabilities and combines them into
one weighted "top result" per task.
"""
print('INFO: import modules')

import base64
import json
import os
import pickle

from flask import Flask, request
import numpy as np

# Imported so that pickle can resolve these classes while unpickling the models.
from required_classes import BertEmbedder, PredictModel


# Per-classifier weight used when combining probabilities in get_top_result().
# Keys must match the classifier file stems once the '_codes' / '_groups'
# suffix has been stripped.
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}


def _load_classifiers(directory):
    """Load every pickled classifier found in *directory*.

    Hidden files (names starting with '.') are skipped.  Loading is
    best-effort: the first failure is reported on stdout and stops the scan,
    but classifiers loaded before the failure are kept.

    NOTE: pickle executes arbitrary code on load; the model files must come
    from a trusted source.

    Returns:
        dict mapping the file name up to its first '.' to the loaded model.
    """
    classifiers = {}
    try:
        for clf_name in os.listdir(directory):
            if clf_name.startswith('.'):
                continue
            with open(os.path.join(directory, clf_name), 'rb') as f:
                classifiers[clf_name.split('.')[0]] = pickle.load(f)
            print(f'INFO: classifier {clf_name} loaded')
    except Exception as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
    return classifiers


print('INFO: loading models')

# Stays None when loading fails so that the failure is reported explicitly at
# prediction time instead of surfacing as a NameError.
embedder = None
try:
    # NOTE: pickle executes arbitrary code on load; only ship trusted files.
    with open('embedder/embedder.pkl', 'rb') as f:
        embedder = pickle.load(f)
    print('INFO: embedder loaded')
except Exception as e:
    print(f"ERROR: loading embedder failed with: {str(e)}")

classifiers_codes = _load_classifiers('classifiers/codes')
classifiers_groups = _load_classifiers('classifiers/groups')


def _classify(classifiers, text, top_n):
    """Return the *top_n* most probable classes of every model in *classifiers*.

    Args:
        classifiers: dict mapping classifier name to a fitted model exposing
            ``predict_proba`` and ``classes_``.
        text: text to classify; it is passed to the module-level ``embedder``.
        top_n: number of classes to keep per classifier.

    Returns:
        dict ``{classifier_name: {class_label: probability}}`` where each inner
        dict is ordered from most to least probable.

    Raises:
        RuntimeError: if the embedder could not be loaded at start-up.
    """
    if embedder is None:
        raise RuntimeError('embedder is not loaded')
    # Wrap the single embedding in a list: predict_proba is called on a batch
    # of one sample.
    embed = [embedder(text)]
    preds = {}
    for clf_name, model in classifiers.items():
        probs = model.predict_proba(embed)
        # argsort is ascending: take the last top_n indices and reverse them
        # so that the most probable class comes first.
        best_n = np.flip(np.argsort(probs, axis=1)[0, -top_n:])
        preds[clf_name] = {
            str(model.classes_[i]): float(probs[0][i]) for i in best_n
        }
    return preds


def classify_code(text, top_n):
    """Return the top-N predictions of every classifier in ``classifiers_codes``.

    See ``_classify`` for the shape of the returned dict.
    """
    return _classify(classifiers_codes, text, top_n)


def classify_group(text, top_n):
    """Return the top-N predictions of every classifier in ``classifiers_groups``.

    See ``_classify`` for the shape of the returned dict.
    """
    return _classify(classifiers_groups, text, top_n)


def get_top_result(preds, threshold=0.5):
    """Combine per-classifier predictions into a single best class.

    Each probability is multiplied by the weight of its classifier (looked up
    in ``CLS_WEIGHTS`` after stripping the '_codes' / '_groups' suffix from the
    classifier name) and the weighted values are summed per class.

    Args:
        preds: dict ``{classifier_name: {class_label: probability}}``.
        threshold: minimum combined score the best class must exceed.

    Returns:
        The class label with the highest combined score, or None when that
        score does not exceed *threshold* or when there are no predictions.

    Raises:
        KeyError: if a classifier name has no entry in ``CLS_WEIGHTS``.
    """
    total_scores = {}
    for clf_name, scores in preds.items():
        weight = CLS_WEIGHTS[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weight * score

    if not total_scores:
        return None

    # max() over the dict compares the real scores.  The previous
    # np.array(dict.values()).argmax() built a 0-d object array, so it did not
    # select the highest-scoring class.
    best_class = max(total_scores, key=total_scores.get)
    if total_scores[best_class] > threshold:
        return best_class
    return None


app = Flask(__name__)


@app.get("/")
def test_get():
    """Liveness probe: return a fixed JSON payload."""
    return {'hello': 'world'}


@app.route("/test", methods=['POST'])
def test():
    """Echo the posted JSON body back under the 'response' key."""
    data = request.json
    return {'response': data}


@app.route("/predict", methods=['POST'])
def read_root():
    """Classify the text submitted in the JSON request body.

    Expected body: ``{"textB64": <base64-encoded text>, "top_n": <int >= 1>}``.

    Returns:
        A JSON object with 'icd10' and 'dx_group' entries, each holding the
        combined 'result' (or null) and the per-classifier 'details'.  On
        invalid input a JSON object with an 'error' key is returned instead.
    """
    data = request.json
    if not isinstance(data, dict) or 'textB64' not in data or 'top_n' not in data:
        return {'error': 'textB64 and top_n are required'}, 400

    try:
        base64_bytes = str(data['textB64']).encode("ascii")
        # UTF-8 is a superset of ASCII, so previously accepted payloads decode
        # identically while non-ASCII text no longer crashes the request.
        text = base64.b64decode(base64_bytes).decode("utf-8")
    except ValueError:
        # Covers binascii.Error as well as Unicode encode/decode errors.
        return {'error': 'textB64 is not valid base64-encoded text'}, 400

    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}, 400

    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    result = {
        "icd10": {
            'result': get_top_result(pred_codes),
            'details': pred_codes,
        },
        "dx_group": {
            'result': get_top_result(pred_groups),
            'details': pred_groups,
        },
    }
    return result


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)