# Source: top10001 / icd8.py
# Uploaded by adithiyyha ("Upload 2 files"), commit d137de1 (verified), 3.16 kB
import streamlit as st
import torch
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
import re
# ----------------------------------------------------------------------
# Text Preprocessing (same as during training)
# ----------------------------------------------------------------------
def preprocess_text(text: str) -> str:
    """Normalize raw clinical text before tokenization.

    The text is lower-cased and then passed through an ordered series of
    regex substitutions; the order matters because later steps clean up
    the spaces introduced by earlier ones.

    Args:
        text: Raw input text.

    Returns:
        The lower-cased text with bracketed ``[** ... **]`` spans removed,
        repeated punctuation collapsed, all whitespace runs reduced to a
        single space, and leading/trailing whitespace stripped.
    """
    substitutions = (
        # Bracketed "[** ... **]" spans (presumably de-identification
        # placeholders) become a single space. The lazy ".*?" does not
        # cross newlines, so only single-line spans are removed.
        (r"\[\*\*.*?\*\*\]", " "),
        # Runs of the same punctuation mark ("!!!", "...") collapse to one.
        (r"([!?.,])\1+", r"\1"),
        # Line breaks and tabs become spaces.
        (r"[\r\n\t]+", " "),
        # Any remaining whitespace run collapses to a single space.
        (r"\s+", " "),
    )

    cleaned = text.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ----------------------------------------------------------------------
# Load Trained Model and Artifacts
# ----------------------------------------------------------------------
@st.cache_resource
def load_trained_model(model_dir: str):
    """Load the classifier, tokenizer and label binarizer stored in *model_dir*.

    The function is wrapped in ``st.cache_resource``, so the returned
    objects are cached by Streamlit rather than rebuilt on every call.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for both the
            model and the tokenizer; it must also contain
            ``mlb_classes.json`` with the list of label classes.

    Returns:
        A ``(model, tokenizer, mlb)`` tuple: the sequence-classification
        model switched to eval mode, its tokenizer, and a
        ``MultiLabelBinarizer`` configured with the stored classes.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(model_dir)
    classifier.eval()
    text_tokenizer = AutoTokenizer.from_pretrained(model_dir)

    with open(f"{model_dir}/mlb_classes.json", "r") as classes_file:
        label_names = json.load(classes_file)

    # The classes are given explicitly, so fitting on one empty sample only
    # initializes the binarizer; it does not learn anything from data.
    binarizer = MultiLabelBinarizer(classes=label_names).fit([[]])
    return classifier, text_tokenizer, binarizer
# ----------------------------------------------------------------------
# Predict ICD-9 Codes
# ----------------------------------------------------------------------
def predict_icd9(input_text: str, model, tokenizer, mlb, max_length=512, threshold=0.5):
    """Predict the label codes for a single piece of clinical text.

    The text is preprocessed, tokenized (truncated and padded to
    *max_length*), run through the model without gradient tracking, and
    every label whose sigmoid score is strictly greater than *threshold*
    is selected.

    Args:
        input_text: Raw clinical text.
        model: Sequence-classification model exposing ``.logits`` on its
            output and a ``.device`` attribute.
        tokenizer: Tokenizer matching *model*.
        mlb: Fitted ``MultiLabelBinarizer`` used to map the 0/1 indicator
            row back to label codes.
        max_length: Token length the input is truncated/padded to.
        threshold: Score a label must exceed to be predicted.

    Returns:
        The tuple of predicted codes for the input (empty if no label
        exceeds the threshold).
    """
    processed_text = preprocess_text(input_text)
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Put the input tensors on the same device as the model weights so the
    # forward pass also works when the model lives on an accelerator.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Keep an explicit (1, num_labels) shape instead of squeeze(): squeezing
    # a single-label output yields a 0-d array, which inverse_transform
    # cannot interpret as an indicator matrix.
    probs = torch.sigmoid(logits).reshape(1, -1).cpu().numpy()
    y_pred = (probs > threshold).astype(int)
    return mlb.inverse_transform(y_pred)[0]
# ----------------------------------------------------------------------
# Streamlit App
# ----------------------------------------------------------------------
st.title("ICD-9 Code Prediction")
# NOTE(review): "final_mode4l" looks like it could be a typo for the model
# directory name — confirm this path exists where the app is deployed.
model_dir = "./final_mode4l"

# Sidebar: decision threshold forwarded to predict_icd9.
st.sidebar.header("Model Settings")
threshold = st.sidebar.slider(
    "Prediction Threshold",
    min_value=0.1,
    max_value=1.0,
    value=0.5,
    step=0.1,
)

# Main panel: free-text input.
st.write("Enter clinical text below to predict ICD-9 codes.")
input_text = st.text_area("Clinical Text", height=200)

if st.button("Predict"):
    if input_text.strip():
        st.write("Loading model...")
        model, tokenizer, mlb = load_trained_model(model_dir)
        st.write("Predicting...")
        predicted_codes = predict_icd9(input_text, model, tokenizer, mlb, threshold=threshold)
        if not predicted_codes:
            st.warning("No codes were predicted. Try lowering the threshold or using a different input.")
        else:
            st.success("Predicted ICD-9 Codes:")
            st.write(predicted_codes)
    else:
        # Whitespace-only input is rejected before the model is loaded.
        st.error("Please enter valid clinical text.")