lyangas committed on
Commit
488bb56
1 Parent(s): 1efad19

add method predict_code for predicting a code within a given group

Browse files
Files changed (1) hide show
  1. app.py +53 -4
app.py CHANGED
@@ -19,7 +19,7 @@ try:
19
  except Exception as e:
20
  print(f"ERROR: loading embedder failed with: {str(e)}")
21
 
22
-
23
  classifiers_codes = {}
24
  try:
25
  for clf_name in os.listdir('classifiers/codes'):
@@ -28,10 +28,11 @@ try:
28
  with open('classifiers/codes/'+clf_name, 'rb') as f:
29
  model = pickle.load(f)
30
  classifiers_codes[clf_name.split('.')[0]] = model
31
- print(f'INFO: classifier {clf_name} loaded')
32
  except Exception as e:
33
  print(f"ERROR: loading classifiers failed with: {str(e)}")
34
 
 
35
  classifiers_groups = {}
36
  try:
37
  for clf_name in os.listdir('classifiers/groups'):
@@ -40,7 +41,21 @@ try:
40
  with open('classifiers/groups/'+clf_name, 'rb') as f:
41
  model = pickle.load(f)
42
  classifiers_groups[clf_name.split('.')[0]] = model
43
- print(f'INFO: classifier {clf_name} loaded')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  except Exception as e:
45
  print(f"ERROR: loading classifiers failed with: {str(e)}")
46
 
@@ -68,6 +83,17 @@ def classify_group(text, top_n):
68
  preds[clf_name] = clf_preds
69
  return preds
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  def get_top_result(preds):
72
  total_scores = {}
73
  for clf_name, scores in preds.items():
@@ -97,7 +123,7 @@ def test():
97
  return {'response': data}
98
 
99
  @app.route("/predict", methods=['POST'])
100
- def read_root():
101
  data = request.json
102
  base64_bytes = str(data['textB64']).encode("ascii")
103
  sample_string_bytes = base64.b64decode(base64_bytes)
@@ -121,5 +147,28 @@ def read_root():
121
  }
122
  return result
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if __name__ == "__main__":
125
  app.run(host='0.0.0.0', port=7860)
 
19
  except Exception as e:
20
  print(f"ERROR: loading embedder failed with: {str(e)}")
21
 
22
+ print('Loading classifiers of codes')
23
  classifiers_codes = {}
24
  try:
25
  for clf_name in os.listdir('classifiers/codes'):
 
28
  with open('classifiers/codes/'+clf_name, 'rb') as f:
29
  model = pickle.load(f)
30
  classifiers_codes[clf_name.split('.')[0]] = model
31
+ print(f'INFO: codes classifier {clf_name} loaded')
32
  except Exception as e:
33
  print(f"ERROR: loading classifiers failed with: {str(e)}")
34
 
35
+ print('Loading classifiers of groups')
36
  classifiers_groups = {}
37
  try:
38
  for clf_name in os.listdir('classifiers/groups'):
 
41
  with open('classifiers/groups/'+clf_name, 'rb') as f:
42
  model = pickle.load(f)
43
  classifiers_groups[clf_name.split('.')[0]] = model
44
+ print(f'INFO: groups classifier {clf_name} loaded')
45
+ except Exception as e:
46
+ print(f"ERROR: loading classifiers failed with: {str(e)}")
47
+
48
# Load the per-group code classifiers. Every non-hidden file in
# classifiers/codes_in_groups is unpickled and stored in groups_models under
# its file name with the '_code_clf.pkl' text removed.
print('Loading classifiers in groups')
groups_models = {}
try:
    for clf_name in os.listdir('classifiers/codes_in_groups'):
        # Entries whose name starts with a dot are ignored.
        if not clf_name.startswith('.'):
            # NOTE(review): pickle executes code on load; these files must
            # come from a trusted source.
            with open('classifiers/codes_in_groups/'+clf_name, 'rb') as f:
                model = pickle.load(f)
            group_name = clf_name.replace('_code_clf.pkl', '')
            groups_models[group_name] = model
            print(f'INFO: codes classifier for group {group_name} loaded')
except Exception as e:
    # A single failure stops the loop; models loaded before it stay registered.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
61
 
 
83
  preds[clf_name] = clf_preds
84
  return preds
85
 
86
def classify_code_by_group(text, group_name, top_n):
    """Rank the codes of one group for a piece of text.

    The text is embedded with ``embedder`` and scored with the classifier
    registered under ``group_name`` in ``groups_models``.

    Args:
        text: Input text to classify.
        group_name: Key of the classifier in ``groups_models``; a missing key
            raises ``KeyError``.
        top_n: Number of highest-probability classes to report.

    Returns:
        A tuple ``(top_cls, top_n_preds, all_codes_in_group)``:
        the most probable class, a dict mapping ``str(class)`` to its
        probability (as ``float``) for the ``top_n`` best classes in
        best-first order, and the classifier's ``classes_`` attribute as-is.
    """
    features = [embedder(text)]
    model = groups_models[group_name]
    probs = model.predict_proba(features)

    # argsort is ascending, so the best candidates are the last top_n indices
    # of the first (only) row; reversing them yields best-first order.
    ascending = np.argsort(probs, axis=1)
    best_n = ascending[0, -top_n:][::-1]

    top_n_preds = {}
    for idx in best_n:
        top_n_preds[str(model.classes_[idx])] = float(probs[0][idx])

    return model.classes_[best_n[0]], top_n_preds, model.classes_
96
+
97
  def get_top_result(preds):
98
  total_scores = {}
99
  for clf_name, scores in preds.items():
 
123
  return {'response': data}
124
 
125
  @app.route("/predict", methods=['POST'])
126
+ def predict_api():
127
  data = request.json
128
  base64_bytes = str(data['textB64']).encode("ascii")
129
  sample_string_bytes = base64.b64decode(base64_bytes)
 
147
  }
148
  return result
149
 
150
def _to_builtin(value):
    """Return ``value`` as plain Python data when it offers ``tolist()``.

    NumPy arrays and NumPy scalars both expose ``tolist()``, which yields
    built-in lists/scalars that a JSON encoder can handle. Any other value is
    returned unchanged.
    """
    tolist = getattr(value, 'tolist', None)
    return tolist() if callable(tolist) else value


@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict a code inside a given group for base64-encoded text.

    Reads the following keys from the JSON request body:
        textB64: Base64-encoded text to classify.
        top_n: Number of best candidates to return (converted with ``int``).
        dx_group: Name of the group whose classifier should be used.

    Returns:
        ``{'error': <message>}`` when ``top_n`` is below 1, the text is blank,
        or no classifier is registered for the group. Otherwise
        ``{'icd10': {'result': ..., 'details': ..., 'all_codes': ...}}`` with
        the top code, the ``top_n`` candidates with their probabilities, and
        every code known to the group's classifier.
    """
    data = request.json
    base64_bytes = str(data['textB64']).encode("ascii")
    sample_string_bytes = base64.b64decode(base64_bytes)
    # UTF-8 is a superset of ASCII: ASCII payloads decode exactly as before,
    # while non-ASCII text no longer fails with UnicodeDecodeError.
    text = sample_string_bytes.decode("utf-8")
    top_n = int(data['top_n'])
    group_name = data['dx_group']

    # NOTE(review): "geather" is a typo for "greater"; the message is left
    # unchanged in case clients match on it.
    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    if group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(text, group_name, top_n)

    # The classifier's classes_ (and elements taken from it) are presumably
    # NumPy objects, which the default JSON encoder rejects; convert them to
    # built-in types before handing the dict to Flask.
    result = {
        "icd10": {
            'result': _to_builtin(top_pred_code),
            'details': pred_codes,
            'all_codes': _to_builtin(all_codes_in_group),
        }
    }
    return result
172
+
173
  if __name__ == "__main__":
174
  app.run(host='0.0.0.0', port=7860)