# File size: 3,160 Bytes
# d137de1
import json
import os
import re

import numpy as np
import streamlit as st
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----------------------------------------------------------------------
# Text Preprocessing (same as during training)
# ----------------------------------------------------------------------
# Ordered (pattern, replacement) rules applied after lower-casing. The order
# must match the training-time pipeline, so do not reshuffle these.
_CLEANUP_RULES = (
    # Drop bracketed [** ... **] spans (presumably de-identification
    # placeholders); `.` does not match newlines, so a span never crosses lines.
    (re.compile(r"\[\*\*.*?\*\*\]"), " "),
    # Collapse runs of the SAME punctuation mark ("!!!" -> "!", ".." -> ".").
    (re.compile(r"([!?.,])\1+"), r"\1"),
    # Turn carriage returns, newlines and tabs into spaces.
    (re.compile(r"[\r\n\t]+"), " "),
    # Squeeze any remaining whitespace run down to one space.
    (re.compile(r"\s+"), " "),
)


def preprocess_text(text: str) -> str:
    """Normalise clinical text the same way it was normalised for training.

    The text is lower-cased, ``[** ... **]`` spans are blanked out, repeated
    identical punctuation is collapsed, and all whitespace is reduced to
    single spaces with the ends trimmed.

    Args:
        text: Raw input text.

    Returns:
        The cleaned string (empty if the input held only whitespace/placeholders).
    """
    normalised = text.lower()
    for pattern, replacement in _CLEANUP_RULES:
        normalised = pattern.sub(replacement, normalised)
    return normalised.strip()

# ----------------------------------------------------------------------
# Load Trained Model and Artifacts
# ----------------------------------------------------------------------
@st.cache_resource
def load_trained_model(model_dir: str):
    """Load the fine-tuned classifier, its tokenizer and the label binarizer.

    Decorated with ``st.cache_resource`` so the artifacts are loaded once per
    ``model_dir`` and reused across Streamlit reruns instead of being re-read
    on every button press.

    Args:
        model_dir: Directory holding the saved model and tokenizer files plus
            ``mlb_classes.json``, a JSON list of label codes whose order
            defines the model's output columns.

    Returns:
        A ``(model, tokenizer, mlb)`` tuple; ``model`` is already in eval mode.

    Raises:
        OSError: if the directory or one of its files cannot be read.
        json.JSONDecodeError: if ``mlb_classes.json`` is not valid JSON.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # inference only: switch off dropout and other train-time behaviour
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows).
    classes_path = os.path.join(model_dir, "mlb_classes.json")
    with open(classes_path, "r", encoding="utf-8") as f:
        top_codes_list = json.load(f)

    # With explicit `classes`, fitting on an empty sample just initialises the
    # binarizer; column order comes from the JSON list, not from data.
    mlb = MultiLabelBinarizer(classes=top_codes_list)
    mlb.fit([[]])
    return model, tokenizer, mlb

# ----------------------------------------------------------------------
# Predict ICD-9 Codes
# ----------------------------------------------------------------------
def predict_icd9(input_text: str, model, tokenizer, mlb, max_length=512, threshold=0.5):
    """Predict the ICD-9 codes whose sigmoid probability exceeds ``threshold``.

    Args:
        input_text: Raw clinical text; normalised with ``preprocess_text``
            before tokenisation.
        model: Sequence-classification model emitting one logit per label
            (assumed to follow the column order of ``mlb.classes_`` — TODO
            confirm against training).
        tokenizer: Tokenizer paired with ``model``.
        mlb: Fitted ``MultiLabelBinarizer`` mapping output columns to codes.
        max_length: Maximum number of tokens; longer inputs are truncated.
        threshold: A label is predicted when its probability is strictly
            greater than this value.

    Returns:
        Tuple of predicted codes, possibly empty.
    """
    processed_text = preprocess_text(input_text)
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Send the tensors to wherever the model lives so a GPU-loaded model works.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the single batch row explicitly. `.squeeze()` would also drop the
    # label axis for a one-label model, yielding a 0-d array that
    # `inverse_transform` cannot decode.
    probs = torch.sigmoid(logits[0]).cpu().numpy()
    y_pred = (probs > threshold).astype(int)
    # inverse_transform expects a 2-D (n_samples, n_classes) indicator matrix.
    predicted_codes = mlb.inverse_transform(y_pred.reshape(1, -1))
    return predicted_codes[0]

# ----------------------------------------------------------------------
# Streamlit App
# ----------------------------------------------------------------------
st.title("ICD-9 Code Prediction")

# NOTE(review): "final_mode4l" looks like a typo of "final_model"; confirm the
# directory name on disk before changing it.
model_dir = "./final_mode4l"

st.sidebar.header("Model Settings")
# Labels are kept when prob > threshold, so at 1.0 nothing can ever be predicted.
threshold = st.sidebar.slider("Prediction Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.1)

st.write("Enter clinical text below to predict ICD-9 codes.")

input_text = st.text_area("Clinical Text", height=200)

if st.button("Predict"):
    if not input_text.strip():
        st.error("Please enter valid clinical text.")
    else:
        # st.spinner shows the status only while the work runs, instead of
        # leaving stale "Loading..." lines on the page afterwards.
        try:
            with st.spinner("Loading model..."):
                model, tokenizer, mlb = load_trained_model(model_dir)
        except (OSError, json.JSONDecodeError) as err:
            # A missing/corrupt model directory: report it instead of crashing the page.
            st.error(f"Could not load the model from {model_dir!r}: {err}")
        else:
            with st.spinner("Predicting..."):
                predicted_codes = predict_icd9(input_text, model, tokenizer, mlb, threshold=threshold)
            if predicted_codes:
                st.success("Predicted ICD-9 Codes:")
                st.write(predicted_codes)
            else:
                st.warning("No codes were predicted. Try lowering the threshold or using a different input.")