File size: 1,904 Bytes
6304a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a7adf1
6304a81
 
 
3a4727b
4a7adf1
 
6304a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
print('INFO: import modules')
from flask import Flask, request
import json
import pickle
import numpy as np

from required_classes import BertEmbedder, PredictModel


print('INFO: loading model')
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever point this at a trusted, locally produced artifact.
# Unpickling presumably needs the classes imported from required_classes
# (BertEmbedder, PredictModel) to be importable, which would explain why they
# are imported although this module never references them directly — confirm.
model = None  # stays None if loading fails, so the name is always defined
try:
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
    # Each request embeds a single text (see `_texts2vecs([text])` below);
    # presumably this is the embedding batch size — TODO confirm.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Deliberate best-effort: keep the server up so "/" and "/test" still
    # respond; "/predict" cannot work while `model` is None.
    print(f"ERROR: loading models failed with: {str(e)}")

def classify_code(text, top_n):
    embed = model._texts2vecs([text])
    probs = model.classifier_code.predict_proba(embed)
    best_n = np.flip(np.argsort(probs, axis=1,)[0,-top_n:])
    preds = [{'code': model.classifier_code.classes_[i], 'proba': probs[0][i]} for i in best_n]
    return preds

def classify_group(text, top_n):
    embed = model._texts2vecs([text])
    probs = model.classifier_group.predict_proba(embed)
    best_n = np.flip(np.argsort(probs, axis=1,)[0,-top_n:])
    preds = [{'group': model.classifier_group.classes_[i], 'proba': probs[0][i]} for i in best_n]
    return preds


# WSGI application object; the route decorators below register on it.
app = Flask(import_name=__name__)

@app.get("/")
def test_get():
    """Liveness probe: answer a GET on the root with a constant payload."""
    payload = dict(hello='world')
    return payload

@app.post("/test")
def test():
    """Echo endpoint: return the parsed JSON request body under ``'response'``."""
    return {'response': request.json}

@app.route("/predict", methods=['POST'])
def read_root():
    """Classify a text into ICD-10 codes and diagnosis groups.

    Expects a JSON object body ``{"text": <str>, "top_n": <int>}``.

    Returns:
        On success, ``{"icd10": {...}, "dx_group": {...}}``; each entry holds
        the best label under ``'result'`` and the ``top_n`` ranked candidates
        under ``'details'``. On invalid input, ``{"error": <message>}`` with
        HTTP status 400.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so clients always receive a JSON error payload.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}, 400
    if 'text' not in data or 'top_n' not in data:
        return {'error': "fields 'text' and 'top_n' are required"}, 400

    text = str(data['text'])
    try:
        top_n = int(data['top_n'])
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}, 400

    # Log only the request shape: the raw text may be sensitive (the output
    # is ICD-10 diagnosis codes), so it is not printed.
    print(f"INFO: predict request: text length={len(text)}, top_n={top_n}")

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}, 400
    if text.strip() == '':
        return {'error': 'text is empty'}, 400

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    return {
        "icd10": {'result': pred_codes[0]['code'], 'details': pred_codes},
        "dx_group": {'result': pred_groups[0]['group'], 'details': pred_groups},
    }

if __name__ == "__main__":
    # Built-in development server, listening on all interfaces at port 7860.
    run_options = {'host': '0.0.0.0', 'port': 7860}
    app.run(**run_options)