adithiyyha commited on
Commit
6e8a29b
Β·
verified Β·
1 Parent(s): e1eca6d

Update icd9_ui.py

Browse files
Files changed (1) hide show
  1. icd9_ui.py +196 -196
icd9_ui.py CHANGED
@@ -1,66 +1,66 @@
1
- import streamlit as st
2
- import torch
3
- from transformers import LongformerTokenizer, LongformerForSequenceClassification
4
 
5
- # Load the fine-tuned model and tokenizer
6
- model_path = "./clinical_longformer"
7
- tokenizer = LongformerTokenizer.from_pretrained(model_path)
8
- model = LongformerForSequenceClassification.from_pretrained(model_path)
9
- model.eval() # Set the model to evaluation mode
10
 
11
- # ICD-9 code columns used during training
12
- icd9_columns = [
13
- '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
14
- '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
15
- '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
16
- '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
17
- '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
18
- '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
19
- ]
20
 
21
- # Function for making predictions
22
- def predict_icd9(texts, tokenizer, model, threshold=0.5):
23
- inputs = tokenizer(
24
- texts,
25
- padding="max_length",
26
- truncation=True,
27
- max_length=512,
28
- return_tensors="pt"
29
- )
30
 
31
- with torch.no_grad():
32
- outputs = model(
33
- input_ids=inputs["input_ids"],
34
- attention_mask=inputs["attention_mask"]
35
- )
36
- logits = outputs.logits
37
- probabilities = torch.sigmoid(logits)
38
- predictions = (probabilities > threshold).int()
39
 
40
- predicted_icd9 = []
41
- for pred in predictions:
42
- codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
43
- predicted_icd9.append(codes)
44
 
45
- return predicted_icd9
46
 
47
- # Streamlit UI
48
- st.title("ICD-9 Code Prediction")
49
- st.sidebar.header("Model Options")
50
- model_option = st.sidebar.selectbox("Select Model", [ "ClinicalLongformer"])
51
- threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
52
-
53
- st.write("### Enter Medical Summary")
54
- input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")
55
-
56
- if st.button("Predict"):
57
- if input_text.strip():
58
- predictions = predict_icd9([input_text], tokenizer, model, threshold)
59
- st.write("### Predicted ICD-9 Codes")
60
- for code in predictions[0]:
61
- st.write(f"- {code}")
62
- else:
63
- st.error("Please enter a medical summary.")
64
 
65
  # import torch
66
  # import pandas as pd
@@ -584,118 +584,118 @@ if st.button("Predict"):
584
  # # else:
585
  # # st.info("πŸ‘† Please upload a medical image to begin analysis")
586
 
587
- # import os
588
- # import torch
589
- # import pandas as pd
590
- # import streamlit as st
591
- # from PIL import Image
592
- # from transformers import LongformerTokenizer, LongformerForSequenceClassification
593
- # from phi.agent import Agent
594
- # from phi.model.google import Gemini
595
- # from phi.tools.duckduckgo import DuckDuckGo
596
 
597
- # # Load the fine-tuned ICD-9 model and tokenizer
598
- # model_path = "./clinical_longformer"
599
- # tokenizer = LongformerTokenizer.from_pretrained(model_path)
600
- # model = LongformerForSequenceClassification.from_pretrained(model_path)
601
- # model.eval() # Set the model to evaluation mode
602
 
603
- # # Load the ICD-9 descriptions from CSV into a dictionary
604
- # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv")
605
- # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching
606
- # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching
607
 
608
- # # ICD-9 code columns used during training
609
- # icd9_columns = [
610
- # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
611
- # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
612
- # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
613
- # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
614
- # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
615
- # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
616
- # ]
617
 
618
- # # Function for making ICD-9 predictions
619
- # def predict_icd9(texts, tokenizer, model, threshold=0.5):
620
- # inputs = tokenizer(
621
- # texts,
622
- # padding="max_length",
623
- # truncation=True,
624
- # max_length=512,
625
- # return_tensors="pt"
626
- # )
627
- # with torch.no_grad():
628
- # outputs = model(
629
- # input_ids=inputs["input_ids"],
630
- # attention_mask=inputs["attention_mask"]
631
- # )
632
- # logits = outputs.logits
633
- # probabilities = torch.sigmoid(logits)
634
- # predictions = (probabilities > threshold).int()
635
 
636
- # predicted_icd9 = []
637
- # for pred in predictions:
638
- # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
639
- # predicted_icd9.append(codes)
640
 
641
- # predictions_with_desc = []
642
- # for codes in predicted_icd9:
643
- # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes]
644
- # predictions_with_desc.append(code_with_desc)
645
 
646
- # return predictions_with_desc
647
 
648
- # # Define the API key directly in the code
649
- # GOOGLE_API_KEY = "AIzaSyA24A6egT3L0NAKkkw9QHjfoizp7cJUTaA"
650
 
651
- # # Streamlit UI
652
- # st.title("Medical Diagnosis Assistant")
653
- # option = st.selectbox(
654
- # "Choose Diagnosis Method",
655
- # ("ICD-9 Code Prediction", "Medical Image Analysis")
656
- # )
657
-
658
- # # ICD-9 Code Prediction
659
- # if option == "ICD-9 Code Prediction":
660
- # st.write("### Enter Medical Summary")
661
- # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")
662
-
663
- # threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
664
-
665
- # if st.button("Predict ICD-9 Codes"):
666
- # if input_text.strip():
667
- # predictions = predict_icd9([input_text], tokenizer, model, threshold)
668
- # st.write("### Predicted ICD-9 Codes and Descriptions")
669
- # for code, description in predictions[0]:
670
- # st.write(f"- {code}: {description}")
671
- # else:
672
- # st.error("Please enter a medical summary.")
673
-
674
- # # Medical Image Analysis
675
- # elif option == "Medical Image Analysis":
676
- # medical_agent = Agent(
677
- # model=Gemini(
678
- # api_key=GOOGLE_API_KEY,
679
- # id="gemini-2.0-flash-exp"
680
- # ),
681
- # tools=[DuckDuckGo()],
682
- # markdown=True
683
- # )
684
 
685
- # query = """
686
- # You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging. Analyze the patient's medical image and structure your response as follows:
687
- # ### 1. Image Type & Region
688
- # - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.)
689
- # - Identify the patient's anatomical region and positioning
690
- # - Comment on image quality and technical adequacy
691
- # ### 2. Key Findings
692
- # - List primary observations systematically
693
- # - Note any abnormalities in the patient's imaging with precise descriptions
694
- # - Include measurements and densities where relevant
695
- # - Describe location, size, shape, and characteristics
696
- # - Rate severity: Normal/Mild/Moderate/Severe
697
- # ### 3. Diagnostic Assessment
698
- # - Provide primary diagnosis with confidence level
699
  # - List differential diagnoses in order of likelihood
700
  # - Support each diagnosis with observed evidence from the patient's imaging
701
  # - Note any critical or urgent findings
@@ -710,49 +710,49 @@ if st.button("Predict"):
710
  # - Include key references to support your analysis
711
  # """
712
 
713
- # upload_container = st.container()
714
- # image_container = st.container()
715
- # analysis_container = st.container()
716
 
717
- # with upload_container:
718
- # uploaded_file = st.file_uploader(
719
- # "Upload Medical Image",
720
- # type=["jpg", "jpeg", "png", "dicom"],
721
- # help="Supported formats: JPG, JPEG, PNG, DICOM"
722
- # )
723
 
724
- # if uploaded_file is not None:
725
- # with image_container:
726
- # col1, col2, col3 = st.columns([1, 2, 1])
727
- # with col2:
728
- # image = Image.open(uploaded_file)
729
- # width, height = image.size
730
- # aspect_ratio = width / height
731
- # new_width = 500
732
- # new_height = int(new_width / aspect_ratio)
733
- # resized_image = image.resize((new_width, new_height))
734
 
735
- # st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True)
736
 
737
- # analyze_button = st.button("πŸ” Analyze Image")
738
 
739
- # with analysis_container:
740
- # if analyze_button:
741
- # image_path = "temp_medical_image.png"
742
- # with open(image_path, "wb") as f:
743
- # f.write(uploaded_file.getbuffer())
744
 
745
- # with st.spinner("πŸ”„ Analyzing image... Please wait."):
746
- # try:
747
- # response = medical_agent.run(query, images=[image_path])
748
- # st.markdown("### πŸ“‹ Analysis Results")
749
- # st.markdown(response.content)
750
- # except Exception as e:
751
- # st.error(f"Analysis error: {e}")
752
- # finally:
753
- # if os.path.exists(image_path):
754
- # os.remove(image_path)
755
- # else:
756
- # st.info("πŸ‘† Please upload a medical image to begin analysis")
757
 
758
 
 
1
+ # import streamlit as st
2
+ # import torch
3
+ # from transformers import LongformerTokenizer, LongformerForSequenceClassification
4
 
5
+ # # Load the fine-tuned model and tokenizer
6
+ # model_path = "./clinical_longformer"
7
+ # tokenizer = LongformerTokenizer.from_pretrained(model_path)
8
+ # model = LongformerForSequenceClassification.from_pretrained(model_path)
9
+ # model.eval() # Set the model to evaluation mode
10
 
11
+ # # ICD-9 code columns used during training
12
+ # icd9_columns = [
13
+ # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
14
+ # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
15
+ # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
16
+ # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
17
+ # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
18
+ # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
19
+ # ]
20
 
21
+ # # Function for making predictions
22
+ # def predict_icd9(texts, tokenizer, model, threshold=0.5):
23
+ # inputs = tokenizer(
24
+ # texts,
25
+ # padding="max_length",
26
+ # truncation=True,
27
+ # max_length=512,
28
+ # return_tensors="pt"
29
+ # )
30
 
31
+ # with torch.no_grad():
32
+ # outputs = model(
33
+ # input_ids=inputs["input_ids"],
34
+ # attention_mask=inputs["attention_mask"]
35
+ # )
36
+ # logits = outputs.logits
37
+ # probabilities = torch.sigmoid(logits)
38
+ # predictions = (probabilities > threshold).int()
39
 
40
+ # predicted_icd9 = []
41
+ # for pred in predictions:
42
+ # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
43
+ # predicted_icd9.append(codes)
44
 
45
+ # return predicted_icd9
46
 
47
+ # # Streamlit UI
48
+ # st.title("ICD-9 Code Prediction")
49
+ # st.sidebar.header("Model Options")
50
+ # model_option = st.sidebar.selectbox("Select Model", [ "ClinicalLongformer"])
51
+ # threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
52
+
53
+ # st.write("### Enter Medical Summary")
54
+ # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")
55
+
56
+ # if st.button("Predict"):
57
+ # if input_text.strip():
58
+ # predictions = predict_icd9([input_text], tokenizer, model, threshold)
59
+ # st.write("### Predicted ICD-9 Codes")
60
+ # for code in predictions[0]:
61
+ # st.write(f"- {code}")
62
+ # else:
63
+ # st.error("Please enter a medical summary.")
64
 
65
  # import torch
66
  # import pandas as pd
 
584
  # # else:
585
  # # st.info("πŸ‘† Please upload a medical image to begin analysis")
586
 
587
+ import os
588
+ import torch
589
+ import pandas as pd
590
+ import streamlit as st
591
+ from PIL import Image
592
+ from transformers import LongformerTokenizer, LongformerForSequenceClassification
593
+ from phi.agent import Agent
594
+ from phi.model.google import Gemini
595
+ from phi.tools.duckduckgo import DuckDuckGo
596
 
597
+ # Load the fine-tuned ICD-9 model and tokenizer
598
+ model_path = "./clinical_longformer"
599
+ tokenizer = LongformerTokenizer.from_pretrained(model_path)
600
+ model = LongformerForSequenceClassification.from_pretrained(model_path)
601
+ model.eval() # Set the model to evaluation mode
602
 
603
+ # Load the ICD-9 descriptions from CSV into a dictionary
604
+ icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv")
605
+ icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching
606
+ icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching
607
 
608
+ # ICD-9 code columns used during training
609
+ icd9_columns = [
610
+ '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
611
+ '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
612
+ '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
613
+ '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
614
+ '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
615
+ '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61'
616
+ ]
617
 
618
+ # Function for making ICD-9 predictions
619
+ def predict_icd9(texts, tokenizer, model, threshold=0.5):
620
+ inputs = tokenizer(
621
+ texts,
622
+ padding="max_length",
623
+ truncation=True,
624
+ max_length=512,
625
+ return_tensors="pt"
626
+ )
627
+ with torch.no_grad():
628
+ outputs = model(
629
+ input_ids=inputs["input_ids"],
630
+ attention_mask=inputs["attention_mask"]
631
+ )
632
+ logits = outputs.logits
633
+ probabilities = torch.sigmoid(logits)
634
+ predictions = (probabilities > threshold).int()
635
 
636
+ predicted_icd9 = []
637
+ for pred in predictions:
638
+ codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1]
639
+ predicted_icd9.append(codes)
640
 
641
+ predictions_with_desc = []
642
+ for codes in predicted_icd9:
643
+ code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes]
644
+ predictions_with_desc.append(code_with_desc)
645
 
646
+ return predictions_with_desc
647
 
648
+ # Define the API key directly in the code
649
+ GOOGLE_API_KEY = "AIzaSyA24A6egT3L0NAKkkw9QHjfoizp7cJUTaA"
650
 
651
+ # Streamlit UI
652
+ st.title("Medical Diagnosis Assistant")
653
+ option = st.selectbox(
654
+ "Choose Diagnosis Method",
655
+ ("ICD-9 Code Prediction", "Medical Image Analysis")
656
+ )
657
+
658
+ # ICD-9 Code Prediction
659
+ if option == "ICD-9 Code Prediction":
660
+ st.write("### Enter Medical Summary")
661
+ input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")
662
+
663
+ threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)
664
+
665
+ if st.button("Predict ICD-9 Codes"):
666
+ if input_text.strip():
667
+ predictions = predict_icd9([input_text], tokenizer, model, threshold)
668
+ st.write("### Predicted ICD-9 Codes and Descriptions")
669
+ for code, description in predictions[0]:
670
+ st.write(f"- {code}: {description}")
671
+ else:
672
+ st.error("Please enter a medical summary.")
673
+
674
+ # Medical Image Analysis
675
+ elif option == "Medical Image Analysis":
676
+ medical_agent = Agent(
677
+ model=Gemini(
678
+ api_key=GOOGLE_API_KEY,
679
+ id="gemini-2.0-flash-exp"
680
+ ),
681
+ tools=[DuckDuckGo()],
682
+ markdown=True
683
+ )
684
 
685
+ query = """
686
+ You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging. Analyze the patient's medical image and structure your response as follows:
687
+ ### 1. Image Type & Region
688
+ - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.)
689
+ - Identify the patient's anatomical region and positioning
690
+ - Comment on image quality and technical adequacy
691
+ ### 2. Key Findings
692
+ - List primary observations systematically
693
+ - Note any abnormalities in the patient's imaging with precise descriptions
694
+ - Include measurements and densities where relevant
695
+ - Describe location, size, shape, and characteristics
696
+ - Rate severity: Normal/Mild/Moderate/Severe
697
+ ### 3. Diagnostic Assessment
698
+ - Provide primary diagnosis with confidence level
699
  # - List differential diagnoses in order of likelihood
700
  # - Support each diagnosis with observed evidence from the patient's imaging
701
  # - Note any critical or urgent findings
 
710
  # - Include key references to support your analysis
711
  # """
712
 
713
+ upload_container = st.container()
714
+ image_container = st.container()
715
+ analysis_container = st.container()
716
 
717
+ with upload_container:
718
+ uploaded_file = st.file_uploader(
719
+ "Upload Medical Image",
720
+ type=["jpg", "jpeg", "png", "dicom"],
721
+ help="Supported formats: JPG, JPEG, PNG, DICOM"
722
+ )
723
 
724
+ if uploaded_file is not None:
725
+ with image_container:
726
+ col1, col2, col3 = st.columns([1, 2, 1])
727
+ with col2:
728
+ image = Image.open(uploaded_file)
729
+ width, height = image.size
730
+ aspect_ratio = width / height
731
+ new_width = 500
732
+ new_height = int(new_width / aspect_ratio)
733
+ resized_image = image.resize((new_width, new_height))
734
 
735
+ st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True)
736
 
737
+ analyze_button = st.button("πŸ” Analyze Image")
738
 
739
+ with analysis_container:
740
+ if analyze_button:
741
+ image_path = "temp_medical_image.png"
742
+ with open(image_path, "wb") as f:
743
+ f.write(uploaded_file.getbuffer())
744
 
745
+ with st.spinner("πŸ”„ Analyzing image... Please wait."):
746
+ try:
747
+ response = medical_agent.run(query, images=[image_path])
748
+ st.markdown("### πŸ“‹ Analysis Results")
749
+ st.markdown(response.content)
750
+ except Exception as e:
751
+ st.error(f"Analysis error: {e}")
752
+ finally:
753
+ if os.path.exists(image_path):
754
+ os.remove(image_path)
755
+ else:
756
+ st.info("πŸ‘† Please upload a medical image to begin analysis")
757
 
758