Update icd9_ui.py
icd9_ui.py  +196 -196
CHANGED
@@ -1,66 +1,66 @@
-import streamlit as st
-import torch
-from transformers import LongformerTokenizer, LongformerForSequenceClassification

-# Load the fine-tuned model and tokenizer
-model_path = "./clinical_longformer"
-tokenizer = LongformerTokenizer.from_pretrained(model_path)
-model = LongformerForSequenceClassification.from_pretrained(model_path)
-model.eval()  # Set the model to evaluation mode

-# ICD-9 code columns used during training
-icd9_columns = [
-    '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
-    '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
-    '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
-    '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
-    '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
-    '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
-]

-# Function for making predictions
-def predict_icd9(texts, tokenizer, model, threshold=0.5):
-    inputs = tokenizer(
-        texts,
-        padding="max_length",
-        truncation=True,
-        max_length=512,
-        return_tensors="pt"
-    )

-    with torch.no_grad():
-        outputs = model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"]
-        )
-    logits = outputs.logits
-    probabilities = torch.sigmoid(logits)
-    predictions = (probabilities > threshold).int()

-    predicted_icd9 = []
-    for pred in predictions:
-        codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
-        predicted_icd9.append(codes)

-    return predicted_icd9

-# Streamlit UI
-st.title("ICD-9 Code Prediction")
-st.sidebar.header("Model Options")
-model_option = st.sidebar.selectbox("Select Model", [ "ClinicalLongformer"])
-threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)

-st.write("### Enter Medical Summary")
-input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")

-if st.button("Predict"):
-    if input_text.strip():
-        predictions = predict_icd9([input_text], tokenizer, model, threshold)
-        st.write("### Predicted ICD-9 Codes")
-        for code in predictions[0]:
-            st.write(f"- {code}")
-    else:
-        st.error("Please enter a medical summary.")

# import torch
# import pandas as pd
@@ -584,118 +584,118 @@ if st.button("Predict"):
# # else:
# # st.info("π Please upload a medical image to begin analysis")

# - List differential diagnoses in order of likelihood
# - Support each diagnosis with observed evidence from the patient's imaging
# - Note any critical or urgent findings
@@ -710,49 +710,49 @@ if st.button("Predict"):
# - Include key references to support your analysis
# """

+# import streamlit as st
+# import torch
+# from transformers import LongformerTokenizer, LongformerForSequenceClassification

+# # Load the fine-tuned model and tokenizer
+# model_path = "./clinical_longformer"
+# tokenizer = LongformerTokenizer.from_pretrained(model_path)
+# model = LongformerForSequenceClassification.from_pretrained(model_path)
+# model.eval()  # Set the model to evaluation mode

+# # ICD-9 code columns used during training
+# icd9_columns = [
+#     '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
+#     '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
+#     '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
+#     '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
+#     '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
+#     '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
+# ]

+# # Function for making predictions
+# def predict_icd9(texts, tokenizer, model, threshold=0.5):
+#     inputs = tokenizer(
+#         texts,
+#         padding="max_length",
+#         truncation=True,
+#         max_length=512,
+#         return_tensors="pt"
+#     )

+#     with torch.no_grad():
+#         outputs = model(
+#             input_ids=inputs["input_ids"],
+#             attention_mask=inputs["attention_mask"]
+#         )
+#     logits = outputs.logits
+#     probabilities = torch.sigmoid(logits)
+#     predictions = (probabilities > threshold).int()

+#     predicted_icd9 = []
+#     for pred in predictions:
+#         codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
+#         predicted_icd9.append(codes)

+#     return predicted_icd9

+# # Streamlit UI
+# st.title("ICD-9 Code Prediction")
+# st.sidebar.header("Model Options")
+# model_option = st.sidebar.selectbox("Select Model", [ "ClinicalLongformer"])
+# threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)

+# st.write("### Enter Medical Summary")
+# input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")

+# if st.button("Predict"):
+#     if input_text.strip():
+#         predictions = predict_icd9([input_text], tokenizer, model, threshold)
+#         st.write("### Predicted ICD-9 Codes")
+#         for code in predictions[0]:
+#             st.write(f"- {code}")
+#     else:
+#         st.error("Please enter a medical summary.")

# import torch
# import pandas as pd

# # else:
# # st.info("π Please upload a medical image to begin analysis")

+import os
+import torch
+import pandas as pd
+import streamlit as st
+from PIL import Image
+from transformers import LongformerTokenizer, LongformerForSequenceClassification
+from phi.agent import Agent
+from phi.model.google import Gemini
+from phi.tools.duckduckgo import DuckDuckGo

+# Load the fine-tuned ICD-9 model and tokenizer
+model_path = "./clinical_longformer"
+tokenizer = LongformerTokenizer.from_pretrained(model_path)
+model = LongformerForSequenceClassification.from_pretrained(model_path)
+model.eval()  # Set the model to evaluation mode

+# Load the ICD-9 descriptions from CSV into a dictionary
+icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv")
+icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str)  # Ensure ICD9_CODE is string type for matching
+icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE']))  # Remove decimals in ICD9 code for matching

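One thing worth flagging in the description lookup above: `Series.str.replace('.', '')` depends on how pandas interprets the pattern. On older pandas releases the pattern is treated as a regular expression by default, and `'.'` then matches every character, which would empty the codes. A minimal sketch, not part of the commit, that makes the literal replacement explicit on any version:

    # Sketch: strip the dot from ICD-9 codes literally, regardless of the pandas default.
    import pandas as pd

    codes = pd.Series(["038.9", "244.9", "V45.81"])
    print(codes.str.replace(".", "", regex=False).tolist())  # ['0389', '2449', 'V4581']

The per-code lookup later in `predict_icd9` uses plain `str.replace`, which is already literal, so only the pandas call is affected.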
+# ICD-9 code columns used during training
+icd9_columns = [
+    '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
+    '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
+    '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
+    '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
+    '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
+    '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
+]

+# Function for making ICD-9 predictions
+def predict_icd9(texts, tokenizer, model, threshold=0.5):
+    inputs = tokenizer(
+        texts,
+        padding="max_length",
+        truncation=True,
+        max_length=512,
+        return_tensors="pt"
+    )
+    with torch.no_grad():
+        outputs = model(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"]
+        )
+    logits = outputs.logits
+    probabilities = torch.sigmoid(logits)
+    predictions = (probabilities > threshold).int()

+    predicted_icd9 = []
+    for pred in predictions:
+        codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
+        predicted_icd9.append(codes)

+    predictions_with_desc = []
+    for codes in predicted_icd9:
+        code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes]
+        predictions_with_desc.append(code_with_desc)

+    return predictions_with_desc

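For readers skimming the diff, the decoding step above is independent multi-label thresholding: each logit is passed through a sigmoid and compared against the slider threshold, so any number of codes (including none) can fire for a single note. A small self-contained sketch with made-up logits and a three-code subset of `icd9_columns`:

    # Sketch with fabricated numbers, only to show how the thresholding behaves.
    import torch

    demo_columns = ['038.9', '244.9', '250.00']   # illustrative subset of icd9_columns
    logits = torch.tensor([[2.0, -1.0, 0.3]])     # pretend model output for one note
    probabilities = torch.sigmoid(logits)         # ~[[0.88, 0.27, 0.57]]
    predictions = (probabilities > 0.5).int()     # [[1, 0, 1]]
    codes = [demo_columns[i] for i, v in enumerate(predictions[0]) if v == 1]
    print(codes)                                  # ['038.9', '250.00']

Separately, `max_length=512` truncates long discharge summaries even though Longformer checkpoints are typically trained for sequences up to 4,096 tokens; if the fine-tuned checkpoint supports the longer window, raising this limit may be worth trying.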
+# Define the API key directly in the code
+GOOGLE_API_KEY = "AIzaSyA24A6egT3L0NAKkkw9QHjfoizp7cJUTaA"

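A review note on the two lines above: a literal Google API key committed to a public Space should be treated as exposed and rotated. A common alternative, shown here as a sketch rather than part of the commit, is to read it from the environment (or from Streamlit's `st.secrets`) at startup:

    # Sketch: pull the key from the environment instead of hard-coding it.
    import os

    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
    if not GOOGLE_API_KEY:
        raise RuntimeError("GOOGLE_API_KEY is not set; configure it before launching the app.")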
+# Streamlit UI
+st.title("Medical Diagnosis Assistant")
+option = st.selectbox(
+    "Choose Diagnosis Method",
+    ("ICD-9 Code Prediction", "Medical Image Analysis")
+)

+# ICD-9 Code Prediction
+if option == "ICD-9 Code Prediction":
+    st.write("### Enter Medical Summary")
+    input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")

+    threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)

+    if st.button("Predict ICD-9 Codes"):
+        if input_text.strip():
+            predictions = predict_icd9([input_text], tokenizer, model, threshold)
+            st.write("### Predicted ICD-9 Codes and Descriptions")
+            for code, description in predictions[0]:
+                st.write(f"- {code}: {description}")
+        else:
+            st.error("Please enter a medical summary.")

+# Medical Image Analysis
+elif option == "Medical Image Analysis":
+    medical_agent = Agent(
+        model=Gemini(
+            api_key=GOOGLE_API_KEY,
+            id="gemini-2.0-flash-exp"
+        ),
+        tools=[DuckDuckGo()],
+        markdown=True
+    )

+    query = """
+    You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging. Analyze the patient's medical image and structure your response as follows:
+    ### 1. Image Type & Region
+    - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.)
+    - Identify the patient's anatomical region and positioning
+    - Comment on image quality and technical adequacy
+    ### 2. Key Findings
+    - List primary observations systematically
+    - Note any abnormalities in the patient's imaging with precise descriptions
+    - Include measurements and densities where relevant
+    - Describe location, size, shape, and characteristics
+    - Rate severity: Normal/Mild/Moderate/Severe
+    ### 3. Diagnostic Assessment
+    - Provide primary diagnosis with confidence level
# - List differential diagnoses in order of likelihood
# - Support each diagnosis with observed evidence from the patient's imaging
# - Note any critical or urgent findings

# - Include key references to support your analysis
# """

+upload_container = st.container()
+image_container = st.container()
+analysis_container = st.container()

+with upload_container:
+    uploaded_file = st.file_uploader(
+        "Upload Medical Image",
+        type=["jpg", "jpeg", "png", "dicom"],
+        help="Supported formats: JPG, JPEG, PNG, DICOM"
+    )

+if uploaded_file is not None:
+    with image_container:
+        col1, col2, col3 = st.columns([1, 2, 1])
+        with col2:
+            image = Image.open(uploaded_file)
+            width, height = image.size
+            aspect_ratio = width / height
+            new_width = 500
+            new_height = int(new_width / aspect_ratio)
+            resized_image = image.resize((new_width, new_height))

+            st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True)

+            analyze_button = st.button("π Analyze Image")

+    with analysis_container:
+        if analyze_button:
+            image_path = "temp_medical_image.png"
+            with open(image_path, "wb") as f:
+                f.write(uploaded_file.getbuffer())

+            with st.spinner("π Analyzing image... Please wait."):
+                try:
+                    response = medical_agent.run(query, images=[image_path])
+                    st.markdown("### π Analysis Results")
+                    st.markdown(response.content)
+                except Exception as e:
+                    st.error(f"Analysis error: {e}")
+                finally:
+                    if os.path.exists(image_path):
+                        os.remove(image_path)
+else:
+    st.info("π Please upload a medical image to begin analysis")
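The image-analysis branch writes every upload to the fixed path "temp_medical_image.png", so two sessions analyzing images at the same time could overwrite each other's file. A possible refinement, sketched here with a hypothetical helper name rather than taken from the commit, is to give each upload its own temporary file:

    # Sketch, not part of the commit: one unique temp file per uploaded image.
    import tempfile

    def save_upload_to_temp(uploaded_file) -> str:
        # uploaded_file is the object returned by st.file_uploader
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(uploaded_file.getbuffer())
            return tmp.name

The caller would pass the returned path to `medical_agent.run(...)` and delete it in the same `finally:` block the diff already uses for cleanup.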