# opdx / app.py
# Last commit: 3a4727b by lyangas - "change return in /test"
# File size: 1.92 kB
print('INFO: import modules')
from flask import Flask, request
import json
import pickle
import numpy as np
from required_classes import BertEmbedder, PredictModel
print('INFO: loading model')
# Bind the name before attempting the load so that, if loading fails, later
# references to ``model`` see ``None`` instead of raising NameError.
model = None
try:
    # NOTE: pickle.load can execute arbitrary code while unpickling; only ever
    # point this at a trusted, locally produced model file.
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
    # Each request scores a single text (the classify_* helpers pass a
    # one-element list); presumably this is the embedder's batch size —
    # confirm against required_classes.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Best-effort start-up: log and keep the web server running. Prediction
    # requests will fail (``model`` is None) until the model can be loaded.
    print(f"ERROR: loading models failed with: {str(e)}")
def _rank_labels(probs, classes, top_n, label_key):
    """Return the ``top_n`` most probable labels for a single sample.

    Args:
        probs: 1-D sequence of class probabilities for one sample, aligned
            index-for-index with ``classes``.
        classes: Class labels (e.g. a fitted classifier's ``classes_``).
        top_n: Number of labels to return. Values below 1 yield ``[]``;
            values larger than ``len(classes)`` return every label.
        label_key: Dict key under which each label is stored.

    Returns:
        A list of ``{label_key: label, 'proba': probability}`` dicts ordered
        from most to least probable. Labels and probabilities are converted
        to native Python types so the result is JSON-serializable.
    """
    if top_n < 1:
        # A plain ``[-0:]`` slice would select *every* class, not none.
        return []
    probs = np.asarray(probs)
    # Ascending argsort -> keep the last top_n indices -> reverse to descending.
    best = np.argsort(probs)[-top_n:][::-1]
    labels = np.asarray(classes)[best].tolist()
    scores = probs[best].tolist()
    return [{label_key: label, 'proba': score}
            for label, score in zip(labels, scores)]


def classify_code(text, top_n):
    """Return the ``top_n`` most probable codes predicted for ``text``.

    Each item is ``{'code': label, 'proba': probability}``, most probable
    first. Uses the module-level ``model``.
    """
    embed = model._texts2vecs([text])
    probs = model.classifier_code.predict_proba(embed)
    return _rank_labels(probs[0], model.classifier_code.classes_, top_n, 'code')


def classify_group(text, top_n):
    """Return the ``top_n`` most probable groups predicted for ``text``.

    Each item is ``{'group': label, 'proba': probability}``, most probable
    first. Uses the module-level ``model``.
    """
    embed = model._texts2vecs([text])
    probs = model.classifier_group.predict_proba(embed)
    return _rank_labels(probs[0], model.classifier_group.classes_, top_n, 'group')
app = Flask(__name__)


@app.route("/", methods=["GET"])
def test_get():
    """Answer ``GET /`` with a fixed JSON payload (a quick "is it up?" check)."""
    greeting = {'hello': 'world'}
    return greeting
@app.route("/test", methods=['POST'])
def test():
    """Debug endpoint: echo a JSON-serializable summary of the POST request.

    ``request.__dict__`` must not be returned directly: it holds the raw WSGI
    environ (input/error stream objects and other server internals) that
    Flask's JSON encoder cannot serialize, so such a response fails with
    HTTP 500 and would expose server internals if it did serialize. Only
    plain request data is echoed here.
    """
    data = {
        'method': request.method,
        'path': request.path,
        'args': request.args.to_dict(flat=False),
        'form': request.form.to_dict(flat=False),
        # None when the body is not (valid) JSON.
        'json': request.get_json(silent=True),
    }
    return {'response': data}
@app.route("/predict", methods=['POST'])
def read_root():
    """Classify form-submitted text with both the code and group classifiers.

    Form fields:
        text: Text to classify; must contain non-whitespace characters.
        top_n: Number of ranked candidates per classifier; integer >= 1.

    Returns:
        On success, ``{"icd10": {...}, "dx_group": {...}}``, where each entry
        holds the single best label (``result``) and the ranked candidates
        (``details``). On invalid input, ``{'error': <message>}``.
        A missing form field raises werkzeug's ``BadRequestKeyError``, which
        Flask turns into HTTP 400.
    """
    data = request.form
    text = str(data['text'])
    raw_top_n = data['top_n']
    try:
        top_n = int(raw_top_n)
    except ValueError:
        # Previously an uncaught ValueError (HTTP 500); report it like the
        # other input validation errors instead.
        return {'error': 'top_n should be an integer'}
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    return {
        "icd10": {'result': pred_codes[0]['code'], 'details': pred_codes},
        "dx_group": {'result': pred_groups[0]['group'], 'details': pred_groups},
    }
if __name__ == "__main__":
    # Start Flask's built-in server, listening on every network interface.
    app.run(port=7860, host='0.0.0.0')