adithiyyha commited on
Commit
70be421
·
verified ·
1 Parent(s): f2415d9

Update icd9_ui.py

Browse files
Files changed (1) hide show
  1. icd9_ui.py +102 -13
icd9_ui.py CHANGED
@@ -62,6 +62,80 @@
62
  # else:
63
  # st.error("Please enter a medical summary.")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  import torch
66
  import pandas as pd
67
  import streamlit as st
@@ -75,8 +149,17 @@ model.eval() # Set the model to evaluation mode
75
 
76
  # Load the ICD-9 descriptions from CSV into a dictionary
77
  icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file
78
- icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching
79
- icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching
 
 
 
 
 
 
 
 
 
80
 
81
  # ICD-9 code columns used during training
82
  icd9_columns = [
@@ -88,7 +171,7 @@ icd9_columns = [
88
  '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
89
  ]
90
 
91
- # Function for making predictions
92
  def predict_icd9(texts, tokenizer, model, threshold=0.5):
93
  inputs = tokenizer(
94
  texts,
@@ -97,7 +180,7 @@ def predict_icd9(texts, tokenizer, model, threshold=0.5):
97
  max_length=512,
98
  return_tensors="pt"
99
  )
100
-
101
  with torch.no_grad():
102
  outputs = model(
103
  input_ids=inputs["input_ids"],
@@ -106,22 +189,27 @@ def predict_icd9(texts, tokenizer, model, threshold=0.5):
106
  logits = outputs.logits
107
  probabilities = torch.sigmoid(logits)
108
  predictions = (probabilities > threshold).int()
109
-
110
  predicted_icd9 = []
111
  for pred in predictions:
112
  codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
113
  predicted_icd9.append(codes)
114
-
115
- # Fetch descriptions for the predicted ICD-9 codes from the pre-loaded descriptions
116
  predictions_with_desc = []
117
  for codes in predicted_icd9:
118
- code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes]
 
 
 
 
 
119
  predictions_with_desc.append(code_with_desc)
120
-
121
  return predictions_with_desc
122
 
123
  # Streamlit UI
124
- st.title("ICD-9 Code Prediction")
125
  st.sidebar.header("Model Options")
126
  threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
127
 
@@ -131,9 +219,10 @@ input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes h
131
  if st.button("Predict"):
132
  if input_text.strip():
133
  predictions = predict_icd9([input_text], tokenizer, model, threshold)
134
- st.write("### Predicted ICD-9 Codes and Descriptions")
135
- for code, description in predictions[0]:
136
- st.write(f"- {code}: {description}")
137
  else:
138
  st.error("Please enter a medical summary.")
139
 
 
 
62
  # else:
63
  # st.error("Please enter a medical summary.")
64
 
65
+ # import torch
66
+ # import pandas as pd
67
+ # import streamlit as st
68
+ # from transformers import LongformerTokenizer, LongformerForSequenceClassification
69
+
70
+ # # Load the fine-tuned model and tokenizer
71
+ # model_path = "./clinical_longformer"
72
+ # tokenizer = LongformerTokenizer.from_pretrained(model_path)
73
+ # model = LongformerForSequenceClassification.from_pretrained(model_path)
74
+ # model.eval() # Set the model to evaluation mode
75
+
76
+ # # Load the ICD-9 descriptions from CSV into a dictionary
77
+ # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file
78
+ # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching
79
+ # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching
80
+
81
+ # # ICD-9 code columns used during training
82
+ # icd9_columns = [
83
+ # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
84
+ # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
85
+ # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
86
+ # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
87
+ # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
88
+ # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
89
+ # ]
90
+
91
+ # # Function for making predictions
92
+ # def predict_icd9(texts, tokenizer, model, threshold=0.5):
93
+ # inputs = tokenizer(
94
+ # texts,
95
+ # padding="max_length",
96
+ # truncation=True,
97
+ # max_length=512,
98
+ # return_tensors="pt"
99
+ # )
100
+
101
+ # with torch.no_grad():
102
+ # outputs = model(
103
+ # input_ids=inputs["input_ids"],
104
+ # attention_mask=inputs["attention_mask"]
105
+ # )
106
+ # logits = outputs.logits
107
+ # probabilities = torch.sigmoid(logits)
108
+ # predictions = (probabilities > threshold).int()
109
+
110
+ # predicted_icd9 = []
111
+ # for pred in predictions:
112
+ # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
113
+ # predicted_icd9.append(codes)
114
+
115
+ # # Fetch descriptions for the predicted ICD-9 codes from the pre-loaded descriptions
116
+ # predictions_with_desc = []
117
+ # for codes in predicted_icd9:
118
+ # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes]
119
+ # predictions_with_desc.append(code_with_desc)
120
+
121
+ # return predictions_with_desc
122
+
123
+ # # Streamlit UI
124
+ # st.title("ICD-9 Code Prediction")
125
+ # st.sidebar.header("Model Options")
126
+ # threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
127
+
128
+ # st.write("### Enter Medical Summary")
129
+ # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")
130
+
131
+ # if st.button("Predict"):
132
+ # if input_text.strip():
133
+ # predictions = predict_icd9([input_text], tokenizer, model, threshold)
134
+ # st.write("### Predicted ICD-9 Codes and Descriptions")
135
+ # for code, description in predictions[0]:
136
+ # st.write(f"- {code}: {description}")
137
+ # else:
138
+ # st.error("Please enter a medical summary.")
139
  import torch
140
  import pandas as pd
141
  import streamlit as st
 
149
 
150
  # Load the ICD-9 descriptions from CSV into a dictionary
151
  icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file
152
+ icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type
153
+ icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals for matching
154
+
155
+ # Load the ICD-9 to ICD-10 mapping
156
+ icd9_to_icd10 = {}
157
+ with open("2015_I9gem.txt", "r") as file:
158
+ for line in file:
159
+ parts = line.strip().split()
160
+ if len(parts) == 3:
161
+ icd9, icd10, _ = parts
162
+ icd9_to_icd10[icd9] = icd10
163
 
164
  # ICD-9 code columns used during training
165
  icd9_columns = [
 
171
  '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
172
  ]
173
 
174
+ # Function for making predictions and mapping to ICD-10
175
  def predict_icd9(texts, tokenizer, model, threshold=0.5):
176
  inputs = tokenizer(
177
  texts,
 
180
  max_length=512,
181
  return_tensors="pt"
182
  )
183
+
184
  with torch.no_grad():
185
  outputs = model(
186
  input_ids=inputs["input_ids"],
 
189
  logits = outputs.logits
190
  probabilities = torch.sigmoid(logits)
191
  predictions = (probabilities > threshold).int()
192
+
193
  predicted_icd9 = []
194
  for pred in predictions:
195
  codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
196
  predicted_icd9.append(codes)
197
+
198
+ # Fetch descriptions and map to ICD-10 codes
199
  predictions_with_desc = []
200
  for codes in predicted_icd9:
201
+ code_with_desc = []
202
+ for code in codes:
203
+ icd9_stripped = code.replace('.', '')
204
+ icd10_code = icd9_to_icd10.get(icd9_stripped, "Mapping not found")
205
+ icd9_desc = icd9_descriptions.get(icd9_stripped, "Description not found")
206
+ code_with_desc.append((code, icd9_desc, icd10_code))
207
  predictions_with_desc.append(code_with_desc)
208
+
209
  return predictions_with_desc
210
 
211
  # Streamlit UI
212
+ st.title("ICD-9 to ICD-10 Code Prediction")
213
  st.sidebar.header("Model Options")
214
  threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
215
 
 
219
  if st.button("Predict"):
220
  if input_text.strip():
221
  predictions = predict_icd9([input_text], tokenizer, model, threshold)
222
+ st.write("### Predicted ICD-9 and ICD-10 Codes with Descriptions")
223
+ for icd9_code, description, icd10_code in predictions[0]:
224
+ st.write(f"- ICD-9: {icd9_code} ({description}) -> ICD-10: {icd10_code}")
225
  else:
226
  st.error("Please enter a medical summary.")
227
 
228
+