|
# Startup progress marker, printed before the module imports that follow.
print('INFO:', 'import modules')
|
from flask import Flask, request |
|
import json |
|
import pickle |
|
import numpy as np |
|
|
|
from required_classes import BertEmbedder, PredictModel |
|
|
|
|
|
print('INFO: loading model')

# Load the fine-tuned model once at import time. Loading is best-effort: on
# failure the error is printed and the app still starts, with ``model`` bound
# to None so request handlers can detect the failure (an unbound name would
# otherwise surface later as a NameError on the first request).
try:
    # NOTE(review): pickle.load can execute arbitrary code; this is only safe
    # if this file comes from a trusted source -- confirm it is never
    # user-supplied.
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
    # Presumably set because each request embeds a single text
    # (``_texts2vecs([text])``) -- confirm against the model class.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Also covers a failure after unpickling (e.g. setting batch_size), so a
    # half-configured model is never left in place.
    model = None
    print(f"ERROR: loading models failed with: {str(e)}")
|
|
|
def _top_n_predictions(classifier, text, top_n, label_key):
    """Rank ``classifier``'s labels for ``text`` and keep the ``top_n`` best.

    Shared implementation of ``classify_code`` and ``classify_group``.

    Args:
        classifier: fitted classifier exposing ``predict_proba`` and
            ``classes_`` (presumably scikit-learn compatible -- confirm).
        text: text to classify; embedded with the module-level ``model``.
        top_n: number of labels to return. Values below 1 yield ``[]``.
        label_key: dict key under which each label is reported.

    Returns:
        List of ``{label_key: label, 'proba': probability}`` dicts ordered by
        descending probability. Labels and probabilities are plain Python
        values, so the result is JSON-serializable.
    """
    if top_n < 1:
        # Guard explicitly: a tail slice ``[-0:]`` would select every class,
        # and a negative ``top_n`` an arbitrary suffix of the ranking.
        return []

    embed = model._texts2vecs([text])
    # A single text goes in, so keep the single row of class probabilities.
    probs = np.asarray(classifier.predict_proba(embed))[0]
    # Indices of the ``top_n`` largest probabilities, best first.
    best_n = np.argsort(probs)[::-1][:top_n]
    # ``tolist`` converts numpy scalars (e.g. np.int64 labels, which the
    # standard json encoder rejects) into native Python values.
    labels = np.asarray(classifier.classes_)[best_n].tolist()
    scores = probs[best_n].tolist()
    return [{label_key: label, 'proba': score}
            for label, score in zip(labels, scores)]


def classify_code(text, top_n):
    """Return the ``top_n`` most probable codes for ``text``, best first.

    Uses ``model.classifier_code``; each entry is
    ``{'code': label, 'proba': probability}``. Returns ``[]`` if ``top_n < 1``.
    """
    return _top_n_predictions(model.classifier_code, text, top_n, 'code')


def classify_group(text, top_n):
    """Return the ``top_n`` most probable groups for ``text``, best first.

    Uses ``model.classifier_group``; each entry is
    ``{'group': label, 'proba': probability}``. Returns ``[]`` if ``top_n < 1``.
    """
    return _top_n_predictions(model.classifier_group, text, top_n, 'group')
|
|
|
|
|
# The Flask application object; the route decorators below register on it.
app = Flask(import_name=__name__)
|
|
|
@app.get("/")
def test_get():
    """Answer GET / with the constant payload ``{'hello': 'world'}``."""
    greeting = dict(hello='world')
    return greeting
|
|
|
@app.route("/test", methods=['POST'])
def test():
    """Echo the parsed JSON request body back under the ``response`` key."""
    return dict(response=request.json)
|
|
|
@app.route("/predict", methods=['POST'])
def read_root():
    """Classify the posted text into codes and diagnosis groups.

    Expects a JSON object body with:
        text: text to classify; must be non-blank.
        top_n: number of ranked candidates per classifier; must be >= 1.

    Returns:
        ``{'icd10': {...}, 'dx_group': {...}}``. Each entry holds the best
        label under ``result`` and the ranked candidates under ``details``.
        Invalid input yields ``{'error': message}``.
    """
    # silent=True: a missing or malformed JSON body becomes None and is
    # rejected below with a JSON error instead of an unhandled exception.
    data = request.get_json(silent=True)
    # NOTE(review): this prints the full payload, including the text being
    # classified, to stdout -- confirm that is acceptable for this data.
    print(data)

    # Structurally malformed requests get HTTP 400 with a JSON error.
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}, 400
    if data.get('text') is None:
        return {'error': 'text is required'}, 400
    text = str(data['text'])
    try:
        top_n = int(data['top_n'])
    except KeyError:
        return {'error': 'top_n is required'}, 400
    except (TypeError, ValueError, OverflowError):
        # e.g. null, "abc", NaN or Infinity.
        return {'error': 'top_n should be an integer'}, 400

    # Value-validation errors keep the default 200 status for client
    # compatibility.
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}

    # Model loading at import time is best-effort; if it failed (``model``
    # unbound or None), report it instead of crashing in the classifiers.
    if globals().get('model') is None:
        return {'error': 'model is not loaded'}, 503

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)

    # top_n >= 1 here, so each list has at least one ranked entry.
    return {
        "icd10": {'result': pred_codes[0]['code'], 'details': pred_codes},
        "dx_group": {'result': pred_groups[0]['group'], 'details': pred_groups},
    }
|
|
|
if __name__ == "__main__":
    # Run Flask's built-in server, listening on every interface at port 7860.
    app.run(port=7860, host='0.0.0.0')
|
|