"""Flask service exposing ICD-10 code and diagnosis-group predictions.

The served model is unpickled from ``model_finetuned_clear.pkl``; the code
below relies on it providing ``_texts2vecs``, ``classifier_code`` and
``classifier_group`` (the latter two with ``predict_proba`` and ``classes_``).
"""
print('INFO: import modules')
from flask import Flask, request
import json
import pickle
import numpy as np

# NOTE(review): these names are not referenced directly; presumably they are
# imported so that unpickling the model can resolve its classes — confirm.
from required_classes import BertEmbedder, PredictModel

MODEL_PATH = 'model_finetuned_clear.pkl'

print('INFO: loading model')
# Defined up front so a failed load leaves a well-defined sentinel instead of
# an undefined name that would make every request fail with NameError.
model = None
try:
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only ever load model files from a trusted source.
    with open(MODEL_PATH, 'rb') as f:
        model = pickle.load(f)
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Best-effort: the server still starts so the health endpoints work;
    # /predict reports the missing model with a 503.
    print(f"ERROR: loading models failed with: {str(e)}")


def _to_builtin(value):
    """Convert a NumPy scalar to the equivalent Python value (else return it unchanged).

    Flask's JSON encoder cannot serialise types such as ``np.int64`` or
    ``np.float32``.
    """
    return value.item() if isinstance(value, np.generic) else value


def _top_n_predictions(classifier, embed, top_n, label_key):
    """Return the ``top_n`` most probable classes of ``classifier`` for one input.

    Args:
        classifier: object with ``predict_proba`` and ``classes_``.
        embed: embedding batch for a single text (only row 0 is used).
        top_n: number of predictions to return; ``<= 0`` yields an empty list.
        label_key: dict key under which the class label is stored.

    Returns:
        List of ``{label_key: label, 'proba': float}`` dicts, most probable first.
    """
    probs = classifier.predict_proba(embed)[0]
    # Reverse the ascending argsort and take a prefix. Unlike slicing with
    # [-top_n:], this does not return every class when top_n == 0.
    best_n = np.argsort(probs)[::-1][:max(top_n, 0)]
    return [
        {label_key: _to_builtin(classifier.classes_[i]), 'proba': float(probs[i])}
        for i in best_n
    ]


def classify_code(text, top_n, embed=None):
    """Return the ``top_n`` most probable ICD-10 codes for ``text``.

    ``embed`` may be passed to reuse an embedding already computed for
    ``text`` (avoids a second pass through the embedder).
    """
    if embed is None:
        embed = model._texts2vecs([text])
    return _top_n_predictions(model.classifier_code, embed, top_n, 'code')


def classify_group(text, top_n, embed=None):
    """Return the ``top_n`` most probable diagnosis groups for ``text``.

    ``embed`` may be passed to reuse an embedding already computed for
    ``text`` (avoids a second pass through the embedder).
    """
    if embed is None:
        embed = model._texts2vecs([text])
    return _top_n_predictions(model.classifier_group, embed, top_n, 'group')


app = Flask(__name__)


@app.get("/")
def test_get():
    """Liveness check returning a constant payload."""
    return {'hello': 'world'}


@app.route("/test", methods=['POST'])
def test():
    """Echo the posted JSON body back under ``response``."""
    data = request.json
    return {'response': data}


@app.route("/predict", methods=['POST'])
def read_root():
    """Predict ICD-10 codes and diagnosis groups for ``{"text": ..., "top_n": ...}``.

    Returns the best code/group plus the ``top_n`` ranked details, or an
    ``{'error': ...}`` payload with HTTP 400 (bad input) or 503 (no model).
    """
    data = request.get_json(silent=True)
    # NOTE(review): this prints the raw request, which may contain sensitive
    # clinical text — consider removing or redacting in production.
    print(data)
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}, 400
    try:
        text = str(data['text'])
        top_n = int(data['top_n'])
    except KeyError as e:
        return {'error': f'missing field: {e.args[0]}'}, 400
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}, 400
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}, 400
    if text.strip() == '':
        return {'error': 'text is empty'}, 400
    if model is None:
        return {'error': 'model is not loaded'}, 503

    # Embed once and share it between both classifiers.
    embed = model._texts2vecs([text])
    pred_codes = classify_code(text, top_n, embed=embed)
    pred_groups = classify_group(text, top_n, embed=embed)
    result = {
        "icd10": {'result': pred_codes[0]['code'], 'details': pred_codes},
        "dx_group": {'result': pred_groups[0]['group'], 'details': pred_groups},
    }
    return result


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)