Spaces:
Runtime error
Runtime error
File size: 1,904 Bytes
6304a81 4a7adf1 6304a81 3a4727b 4a7adf1 6304a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# Startup progress marker, emitted before the heavier imports below run.
print('INFO:', 'import modules')
from flask import Flask, request
import json
import pickle
import numpy as np
from required_classes import BertEmbedder, PredictModel
# Load the pickled model once at import time so every request reuses it.
print('INFO: loading model')
try:
    # NOTE(security): pickle.load can execute arbitrary code stored in the
    # file. Only ever point this at a trusted, locally produced model file.
    # Unpickling presumably needs BertEmbedder / PredictModel, which is why
    # they are imported from required_classes even though they are not
    # referenced directly — TODO confirm.
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
    # The request handlers embed a single text per call; presumably this
    # sizes the model's internal batching for that — TODO confirm what
    # batch_size controls inside the model object.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Best-effort startup: keep the process (and its model-independent
    # routes) alive and report the failure. Binding `model` to None ensures
    # the name always exists and never refers to a half-configured object.
    model = None
    print(f"ERROR: loading models failed with: {e}")
def top_n_predictions(probs, classes, top_n, label_key):
    """Return the ``top_n`` most probable labels for a single sample.

    Args:
        probs: 1-D sequence of class probabilities for one sample, aligned
            index-for-index with ``classes``.
        classes: Sequence of class labels (e.g. a fitted classifier's
            ``classes_``).
        top_n: Maximum number of labels to return. Values below 1 yield an
            empty list; values larger than the number of classes return all
            classes.
        label_key: Dictionary key under which each label is stored.

    Returns:
        A list of ``{label_key: label, 'proba': probability}`` dicts ordered
        from most to least probable. NumPy scalar labels and probabilities are
        converted to plain Python values so the result is JSON-serializable.
    """
    probs = np.asarray(probs)
    # argsort is ascending; reverse it so the most probable class is first.
    # Slicing from the front (instead of ``[-top_n:]``) keeps top_n == 0 from
    # selecting every class and negative top_n from selecting a wrong slice.
    ranked = np.argsort(probs)[::-1][:max(top_n, 0)]
    preds = []
    for i in ranked:
        label = classes[i]
        if isinstance(label, np.generic):
            # e.g. np.int64 / np.str_ -> int / str; Flask's JSON encoder
            # cannot serialize NumPy integer or float32 scalars.
            label = label.item()
        preds.append({label_key: label, 'proba': float(probs[i])})
    return preds


def classify_code(text, top_n):
    """Rank the code classifier's labels for ``text``.

    Embeds ``text`` with ``model._texts2vecs`` and scores it with
    ``model.classifier_code``.

    Returns:
        Up to ``top_n`` dicts ``{'code': label, 'proba': float}``, most
        probable first.
    """
    embed = model._texts2vecs([text])
    probs = model.classifier_code.predict_proba(embed)
    # A single text was embedded, so row 0 holds its class probabilities.
    return top_n_predictions(probs[0], model.classifier_code.classes_, top_n, 'code')


def classify_group(text, top_n):
    """Rank the group classifier's labels for ``text``.

    Embeds ``text`` with ``model._texts2vecs`` and scores it with
    ``model.classifier_group``.

    Returns:
        Up to ``top_n`` dicts ``{'group': label, 'proba': float}``, most
        probable first.
    """
    embed = model._texts2vecs([text])
    probs = model.classifier_group.predict_proba(embed)
    # A single text was embedded, so row 0 holds its class probabilities.
    return top_n_predictions(probs[0], model.classifier_group.classes_, top_n, 'group')
# The WSGI application object; routes below are registered on it.
app = Flask(import_name=__name__)
@app.get("/")
def test_get():
    """Liveness check: answer GET / with a fixed JSON greeting."""
    greeting = dict(hello='world')
    return greeting
@app.route("/test", methods=['POST'])
def test():
    """Echo endpoint: return the parsed JSON request body under ``response``."""
    return {'response': request.json}
@app.route("/predict", methods=['POST'])
def read_root():
    """Classify a text and return its most likely ICD-10 codes and groups.

    Expects a JSON object body with:
        text:  the text to classify (coerced with ``str``; must not be blank).
        top_n: how many candidates to return per classifier (coerced with
               ``int``; must be at least 1).

    Returns:
        On success, ``{"icd10": {...}, "dx_group": {...}}``. Each entry holds
        the top prediction under ``result`` and the ranked candidates under
        ``details``. On invalid input, ``{"error": message}`` with HTTP 400.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so the client always receives a JSON error below.
    data = request.get_json(silent=True)
    # NOTE(review): this logs the raw request, including the submitted text.
    # If that text can contain patient data, consider logging less.
    print(data)
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}, 400
    if 'text' not in data or 'top_n' not in data:
        return {'error': "request must include 'text' and 'top_n'"}, 400

    text = str(data['text'])
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError, OverflowError):
        return {'error': 'top_n should be an integer'}, 400
    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}, 400
    if not text.strip():
        return {'error': 'text is empty'}, 400

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    # top_n >= 1, so index 0 is the single most probable label of each list.
    return {
        "icd10": {'result': pred_codes[0]['code'], 'details': pred_codes},
        "dx_group": {'result': pred_groups[0]['group'], 'details': pred_groups},
    }
if __name__ == "__main__":
    # Run the development server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')
|