adithiyyha commited on
Commit
d137de1
·
verified ·
1 Parent(s): bf72463

Upload 2 files

Browse files
Files changed (2) hide show
  1. icd8.py +80 -0
  2. requirements.txt +5 -0
icd8.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import json
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ from sklearn.preprocessing import MultiLabelBinarizer
6
+ import numpy as np
7
+ import re
8
+
9
+ # ----------------------------------------------------------------------
10
+ # Text Preprocessing (same as during training)
11
+ # ----------------------------------------------------------------------
12
+ def preprocess_text(text: str) -> str:
13
+ text = text.lower()
14
+ text = re.sub(r"\[\*\*.*?\*\*\]", " ", text)
15
+ text = re.sub(r"([!?.,])\1+", r"\1", text)
16
+ text = re.sub(r"[\r\n\t]+", " ", text)
17
+ text = re.sub(r"\s+", " ", text)
18
+ text = text.strip()
19
+ return text
20
+
21
+ # ----------------------------------------------------------------------
22
+ # Load Trained Model and Artifacts
23
+ # ----------------------------------------------------------------------
24
+ @st.cache_resource
25
+ def load_trained_model(model_dir: str):
26
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
27
+ model.eval()
28
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
29
+ with open(f"{model_dir}/mlb_classes.json", "r") as f:
30
+ top_codes_list = json.load(f)
31
+ mlb = MultiLabelBinarizer(classes=top_codes_list)
32
+ mlb.fit([[]])
33
+ return model, tokenizer, mlb
34
+
35
+ # ----------------------------------------------------------------------
36
+ # Predict ICD-9 Codes
37
+ # ----------------------------------------------------------------------
38
+ def predict_icd9(input_text: str, model, tokenizer, mlb, max_length=512, threshold=0.5):
39
+ processed_text = preprocess_text(input_text)
40
+ inputs = tokenizer(
41
+ processed_text,
42
+ return_tensors="pt",
43
+ truncation=True,
44
+ max_length=max_length,
45
+ padding="max_length"
46
+ )
47
+ with torch.no_grad():
48
+ logits = model(**inputs).logits
49
+ probs = torch.sigmoid(logits).squeeze().cpu().numpy()
50
+ y_pred = (probs > threshold).astype(int)
51
+ predicted_codes = mlb.inverse_transform(np.array([y_pred]))
52
+ return predicted_codes[0]
53
+
54
+ # ----------------------------------------------------------------------
55
+ # Streamlit App
56
+ # ----------------------------------------------------------------------
57
+ st.title("ICD-9 Code Prediction")
58
+
59
+ model_dir = "./final_mode4l"
60
+
61
+ st.sidebar.header("Model Settings")
62
+ threshold = st.sidebar.slider("Prediction Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.1)
63
+
64
+ st.write("Enter clinical text below to predict ICD-9 codes.")
65
+
66
+ input_text = st.text_area("Clinical Text", height=200)
67
+
68
+ if st.button("Predict"):
69
+ if not input_text.strip():
70
+ st.error("Please enter valid clinical text.")
71
+ else:
72
+ st.write("Loading model...")
73
+ model, tokenizer, mlb = load_trained_model(model_dir)
74
+ st.write("Predicting...")
75
+ predicted_codes = predict_icd9(input_text, model, tokenizer, mlb, threshold=threshold)
76
+ if predicted_codes:
77
+ st.success("Predicted ICD-9 Codes:")
78
+ st.write(predicted_codes)
79
+ else:
80
+ st.warning("No codes were predicted. Try lowering the threshold or using a different input.")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ scikit-learn
4
+ numpy
5
+ streamlit