Spaces:
Sleeping
Sleeping
import json
import re

import numpy as np
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ----------------------------------------------------------------------
# Text Preprocessing (same as during training)
# ----------------------------------------------------------------------
def preprocess_text(text: str) -> str:
    """
    Normalize raw text before tokenization.

    The steps run in this order:
      1. Lowercase everything.
      2. Replace bracketed de-identification markers ``[**...**]`` with a
         space (a marker is only matched when it sits on a single line).
      3. Collapse a run of the same punctuation mark (``!``, ``?``, ``.``,
         ``,``) down to one occurrence.
      4. Turn carriage returns, newlines and tabs into spaces.
      5. Squeeze any remaining whitespace run into a single space.
      6. Trim leading and trailing whitespace.

    Returns the cleaned string.
    """
    # Ordered (pattern, replacement) pairs; each rule sees the output of
    # the previous one, so the order must not change.
    cleanup_rules = (
        (r"\[\*\*.*?\*\*\]", " "),   # de-identified brackets
        (r"([!?.,])\1+", r"\1"),     # repeated punctuation
        (r"[\r\n\t]+", " "),         # line breaks / tabs
        (r"\s+", " "),               # whitespace runs
    )

    cleaned = text.lower()
    for pattern, replacement in cleanup_rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ----------------------------------------------------------------------
# Load Trained Model and Artifacts
# ----------------------------------------------------------------------
def load_trained_model(model_dir: str):
    """
    Load the trained model, tokenizer, and MultiLabelBinarizer.

    Args:
        model_dir: Directory holding the saved model and tokenizer files,
            plus an ``mlb_classes.json`` file containing the JSON list of
            label classes.

    Returns:
        A ``(model, tokenizer, mlb)`` tuple. The model has been switched
        to evaluation mode, and ``mlb`` is a ``MultiLabelBinarizer``
        initialized with the classes read from ``mlb_classes.json``.
    """
    # Load the model and put it in evaluation mode for inference.
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()

    # Load the tokenizer saved alongside the model.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Read the label classes. The encoding is given explicitly so that
    # decoding does not depend on the platform's default locale.
    with open(f"{model_dir}/mlb_classes.json", "r", encoding="utf-8") as f:
        top_codes_list = json.load(f)

    # Fitting on a single empty sample initializes the binarizer's
    # internal state using only the explicitly supplied classes.
    mlb = MultiLabelBinarizer(classes=top_codes_list)
    mlb.fit([[]])

    return model, tokenizer, mlb
# ----------------------------------------------------------------------
# Predict ICD-9 Codes
# ----------------------------------------------------------------------
def predict_icd9(input_text: str, model, tokenizer, mlb, max_length=512, threshold=0.5):
    """
    Predict ICD-9 codes for a given clinical text.

    Args:
        input_text: Raw text to classify; it is cleaned with
            ``preprocess_text`` before tokenization.
        model: Sequence-classification model whose output exposes
            ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        mlb: Fitted label binarizer used to map the 0/1 prediction
            vector back to codes via ``inverse_transform``.
        max_length: Token length the input is truncated/padded to.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        The first (and only) entry of ``mlb.inverse_transform(...)`` for
        this input, i.e. a tuple of the predicted codes (empty if no
        label exceeds the threshold).
    """
    # Preprocess the input text
    processed_text = preprocess_text(input_text)

    # Tokenize the input
    inputs = tokenizer(
        processed_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )

    # Move the input tensors to the model's device so inference also
    # works when the caller has placed the model on an accelerator.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Perform inference
    with torch.no_grad():
        logits = model(**inputs).logits

    # Take the single batch row explicitly rather than squeeze(): squeeze
    # would also drop the label axis when the model has only one label.
    probs = torch.sigmoid(logits)[0].cpu().numpy()

    # Apply threshold to get predicted labels
    y_pred = (probs > threshold).astype(int)

    # inverse_transform expects a 2D (samples, labels) indicator array.
    predicted_codes = mlb.inverse_transform(y_pred.reshape(1, -1))
    return predicted_codes[0]
# ----------------------------------------------------------------------
# Inference Example
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # The model, tokenizer files and mlb_classes.json are read from the
    # current working directory.
    artifacts_dir = "./"
    classifier, text_tokenizer, label_binarizer = load_trained_model(artifacts_dir)

    # Sample clinical text used to demonstrate inference.
    sample_note = "Acute nasopharyngitis , true acute abnormality "

    # Run the prediction with a 0.2 probability cut-off and show the result.
    codes = predict_icd9(
        sample_note,
        classifier,
        text_tokenizer,
        label_binarizer,
        threshold=0.2,
    )
    print("Predicted ICD-9 Codes:", codes)