File size: 3,858 Bytes
6304a81
d5e3dc4
6304a81
 
 
 
da7535b
6304a81
 
 
 
da7535b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6304a81
da7535b
 
 
 
 
 
 
6304a81
da7535b
 
6304a81
 
da7535b
 
 
 
 
 
 
 
6304a81
 
da7535b
6304a81
da7535b
 
 
 
 
 
 
 
6304a81
 
da7535b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6304a81
 
 
 
 
 
 
 
 
4a7adf1
6304a81
 
 
3a4727b
4a7adf1
afc3da6
d5e3dc4
 
6304a81
 
 
 
 
 
 
 
 
da7535b
 
6304a81
 
da7535b
6304a81
da7535b
6304a81
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
print('INFO: import modules')
import base64
from flask import Flask, request
import json
import pickle
import numpy as np
import os

from required_classes import BertEmbedder, PredictModel


# Per-classifier-family weights used by get_top_result() to blend the
# probabilities of the individual classifiers into one ensemble score.
CLS_WEIGHTS = {'mlp': 0.3, 'svc': 0.4, 'xgboost': 0.3}


def _load_classifiers(directory):
    """Unpickle every non-hidden file in *directory*.

    Returns a dict mapping each file name up to its first '.' (e.g.
    ``svc_codes.pkl`` -> ``svc_codes``) to the unpickled object. Files are
    loaded in sorted order so the resulting dict order is deterministic.
    A missing/unreadable directory is reported and yields an empty dict; a
    file that fails to unpickle is reported and skipped, so one bad file no
    longer prevents the remaining classifiers from loading.

    NOTE(review): pickle can execute arbitrary code while loading; these
    files must come from a trusted source.
    """
    classifiers = {}
    try:
        file_names = sorted(os.listdir(directory))
    except OSError as e:
        print(f"ERROR: loading classifiers failed with: {str(e)}")
        return classifiers
    for clf_name in file_names:
        if clf_name.startswith('.'):
            continue  # skip hidden files (e.g. .gitkeep)
        try:
            with open(os.path.join(directory, clf_name), 'rb') as f:
                classifiers[clf_name.split('.')[0]] = pickle.load(f)
        except Exception as e:  # unpickling can raise many exception types
            print(f"ERROR: loading classifier {clf_name} failed with: {str(e)}")
            continue
        print(f'INFO: classifier {clf_name} loaded')
    return classifiers


print('INFO: loading models')
# Bound up front so the name always exists, even if the pickle cannot be loaded.
embedder = None
try:
    with open('embedder/embedder.pkl', 'rb') as f:
        embedder = pickle.load(f)
    print('INFO: embedder loaded')
except Exception as e:
    print(f"ERROR: loading embedder failed with: {str(e)}")

classifiers_codes = _load_classifiers('classifiers/codes')
classifiers_groups = _load_classifiers('classifiers/groups')


def _predict_top_n(classifiers, embed, top_n):
    """Return the ``top_n`` most probable classes from every classifier.

    Args:
        classifiers: mapping of classifier name to a fitted model exposing
            ``predict_proba`` and ``classes_``.
        embed: a one-sample batch (a list holding a single embedding) passed
            as-is to ``predict_proba``; only the first row of its output is used.
        top_n: number of classes to keep per classifier. Values below 1 give
            an empty dict per classifier (previously ``top_n == 0`` returned
            every class because ``[-0:]`` slices the whole array).

    Returns:
        ``{clf_name: {class_label: probability}}`` with each inner dict
        ordered from most to least probable.
    """
    keep = max(int(top_n), 0)
    preds = {}
    for clf_name, model in classifiers.items():
        probs = model.predict_proba(embed)[0]
        # Indices in descending-probability order, truncated to `keep`.
        best_n = np.argsort(probs)[::-1][:keep]
        preds[clf_name] = {str(model.classes_[i]): float(probs[i]) for i in best_n}
    return preds


def classify_code(text, top_n):
    """Score *text* with every classifier in ``classifiers_codes``.

    Returns ``{clf_name: {code: probability}}`` holding the ``top_n`` most
    probable codes per classifier, most probable first.
    """
    return _predict_top_n(classifiers_codes, [embedder(text)], top_n)


def classify_group(text, top_n):
    """Score *text* with every classifier in ``classifiers_groups``.

    Returns ``{clf_name: {group: probability}}`` holding the ``top_n`` most
    probable groups per classifier, most probable first.
    """
    return _predict_top_n(classifiers_groups, [embedder(text)], top_n)

def get_top_result(preds, threshold=0.5, weights=None):
    """Blend per-classifier predictions into a single ensemble label.

    Each classifier name has ``_codes``/``_groups`` removed to find its weight
    (e.g. ``svc_codes`` -> ``svc``). Every class score is multiplied by that
    weight, and the products are summed across classifiers.

    Args:
        preds: ``{clf_name: {class_label: probability}}``.
        threshold: the best blended score must be strictly greater than this
            for a label to be returned (default 0.5, as before).
        weights: per-classifier weights; defaults to ``CLS_WEIGHTS``.

    Returns:
        The label with the highest blended score, or None when ``preds`` holds
        no scores or the best score does not exceed ``threshold``.

    Raises:
        KeyError: if a classifier with scores has no entry in ``weights``.
    """
    if weights is None:
        weights = CLS_WEIGHTS
    total_scores = {}
    for clf_name, scores in preds.items():
        if not scores:
            continue
        weight = weights[clf_name.replace('_codes', '').replace('_groups', '')]
        for class_name, score in scores.items():
            total_scores[class_name] = total_scores.get(class_name, 0.0) + weight * score

    if not total_scores:
        return None
    # The previous np.array(dict.values()).argmax() built a 0-d object array
    # whose argmax is always 0, so it returned the first label, not the best.
    best = max(total_scores, key=total_scores.get)
    return best if total_scores[best] > threshold else None


app = Flask(__name__)


@app.get("/")
def test_get():
    """Liveness probe: answer every GET on the root with a fixed payload."""
    greeting = {'hello': 'world'}
    return greeting


@app.post("/test")
def test():
    """Echo endpoint: return the posted JSON body under the 'response' key."""
    payload = request.json
    echoed = {'response': payload}
    return echoed

@app.route("/predict", methods=['POST'])
def read_root():
    """Predict ICD-10 codes and diagnosis groups for a base64-encoded text.

    Expects a JSON object body ``{"textB64": <base64 of UTF-8 text>,
    "top_n": <int >= 1>}``. Returns ``{"icd10": {...}, "dx_group": {...}}``.
    Each part holds the ensemble ``result`` (or None) and the per-classifier
    ``details``. Invalid input yields ``{"error": <message>}`` instead.

    NOTE(review): errors keep the existing convention of a 200 response with
    an ``error`` key; consider returning HTTP 400 once clients allow it.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return {'error': 'request body must be a JSON object'}

    try:
        # UTF-8 is a superset of ASCII: previously accepted input still
        # decodes, and non-ASCII text no longer crashes the handler.
        text = base64.b64decode(str(data['textB64'])).decode("utf-8")
    except KeyError:
        return {'error': 'textB64 is required'}
    except ValueError:  # binascii.Error, non-ASCII base64 or invalid UTF-8
        return {'error': 'textB64 must be base64-encoded UTF-8 text'}

    try:
        top_n = int(data['top_n'])
    except KeyError:
        return {'error': 'top_n is required'}
    except (TypeError, ValueError):
        return {'error': 'top_n should be an integer'}

    if top_n < 1:
        return {'error': 'top_n should be greater than 0'}
    if not text.strip():
        return {'error': 'text is empty'}

    pred_codes = classify_code(text, top_n)
    pred_groups = classify_group(text, top_n)
    return {
        "icd10": {'result': get_top_result(pred_codes), 'details': pred_codes},
        "dx_group": {'result': get_top_result(pred_groups), 'details': pred_groups},
    }

if __name__ == "__main__":
    # Start Flask's built-in server, listening on all interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')