import streamlit as st
import torch
from transformers import LongformerTokenizer, LongformerForSequenceClassification

# Directory holding the fine-tuned Clinical-Longformer checkpoint and tokenizer.
model_path = "./clinical_longformer"


@st.cache_resource
def load_model(path):
    """Load the tokenizer and classifier from *path* once and reuse them.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resource avoids re-reading the checkpoint from disk on
    every click or slider move.

    Returns:
        A ``(tokenizer, model)`` pair with the model in evaluation mode.
    """
    loaded_tokenizer = LongformerTokenizer.from_pretrained(path)
    loaded_model = LongformerForSequenceClassification.from_pretrained(path)
    loaded_model.eval()  # Set the model to evaluation mode (disables dropout)
    return loaded_tokenizer, loaded_model


tokenizer, model = load_model(model_path)

# ICD-9 code columns used during training, in the order of the model's output logits
icd9_columns = [
    '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9',
    '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61',
    '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0',
    '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0',
    '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15',
    '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61',
]


def predict_icd9(texts, tokenizer, model, threshold=0.5, labels=None):
    """Predict ICD-9 codes for each clinical note (multi-label classification).

    Each output logit goes through an independent sigmoid, and every label
    whose probability is strictly greater than *threshold* is returned.

    Args:
        texts: A list of clinical-note strings (a single string is treated
            as a batch of one).
        tokenizer: Tokenizer matching *model*.
        model: Sequence-classification model with one logit per label.
        threshold: Probability cut-off; a label is predicted when its
            sigmoid probability is strictly above this value.
        labels: Label names in the order of the model's output logits.
            Defaults to ``icd9_columns``.

    Returns:
        One list per input text, each holding the predicted label strings
        in label order (possibly empty). An empty input yields ``[]``.

    Raises:
        ValueError: If the model's number of outputs differs from ``len(labels)``.
    """
    if labels is None:
        labels = icd9_columns
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return []

    # NOTE(review): notes longer than 512 tokens are silently truncated;
    # presumably 512 matches the training setup -- confirm before raising it.
    inputs = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits

    # Guard against a checkpoint/label-list mismatch, which would otherwise
    # raise IndexError or silently drop labels.
    if logits.shape[-1] != len(labels):
        raise ValueError(
            f"Model produces {logits.shape[-1]} outputs but {len(labels)} labels were given"
        )

    hits = (torch.sigmoid(logits) > threshold).tolist()
    return [[label for label, hit in zip(labels, row) if hit] for row in hits]


# Streamlit UI
st.title("ICD-9 Code Prediction")
st.sidebar.header("Model Options")
# Only one model is offered; the selection is not read anywhere below.
model_option = st.sidebar.selectbox("Select Model", ["ClinicalLongformer"])
threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01)

st.write("### Enter Medical Summary")
input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...")

if st.button("Predict"):
    if input_text.strip():
        with st.spinner("Predicting ICD-9 codes..."):
            predictions = predict_icd9([input_text], tokenizer, model, threshold)
        st.write("### Predicted ICD-9 Codes")
        if predictions[0]:
            for code in predictions[0]:
                st.write(f"- {code}")
        else:
            st.info("No ICD-9 codes exceeded the selected threshold.")
    else:
        st.error("Please enter a medical summary.")

# NOTE(review): the commented-out text below is earlier revisions of this app;
# consider deleting it and relying on version control instead.
# import torch # import pandas as pd # import streamlit as st # from transformers import LongformerTokenizer, LongformerForSequenceClassification # # Load the fine-tuned model and tokenizer # model_path = "./clinical_longformer" # tokenizer = LongformerTokenizer.from_pretrained(model_path) # model = LongformerForSequenceClassification.from_pretrained(model_path) # model.eval() # Set the model to evaluation mode # # Load the ICD-9 descriptions from CSV into a dictionary # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching # # ICD-9 code columns used during training # icd9_columns = [ # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9', # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61', # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0', # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0', # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15', # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61' # ] # # Function for making predictions # def predict_icd9(texts, tokenizer, model, threshold=0.5): # inputs = tokenizer( # texts, # padding="max_length", # truncation=True, # max_length=512, # return_tensors="pt" # ) # with torch.no_grad(): # outputs = model( # input_ids=inputs["input_ids"], # attention_mask=inputs["attention_mask"] # ) # logits = outputs.logits # probabilities = torch.sigmoid(logits) # predictions = (probabilities > threshold).int() # predicted_icd9 = [] # for pred in predictions: # codes = [icd9_columns[i] for i, val in
enumerate(pred) if val == 1] # predicted_icd9.append(codes) # # Fetch descriptions for the predicted ICD-9 codes from the pre-loaded descriptions # predictions_with_desc = [] # for codes in predicted_icd9: # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes] # predictions_with_desc.append(code_with_desc) # return predictions_with_desc # st.title("ICD-9 Code Prediction") # st.sidebar.header("Model Options") # threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01) # st.write("### Enter Medical Summary") # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...") # if st.button("Predict"): # if input_text.strip(): # predictions = predict_icd9([input_text], tokenizer, model, threshold) # st.write("### Predicted ICD-9 Codes and Descriptions") # for code, description in predictions[0]: # st.write(f"- {code}: {description}") # else: # st.error("Please enter a medical summary.") # import torch # # # # import pandas as pd # # # # import streamlit as st # # # # from transformers import LongformerTokenizer, LongformerForSequenceClassification # # # # Load the fine-tuned model and tokenizer # # # model_path = "./clinical_longformer" # # # tokenizer = LongformerTokenizer.from_pretrained(model_path) # # # model = LongformerForSequenceClassification.from_pretrained(model_path) # # # model.eval() # Set the model to evaluation mode # # # # Load the ICD-9 descriptions from CSV into a dictionary # # # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file # # # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type # # # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals for matching # # # # Load the ICD-9 to ICD-10 mapping # # # icd9_to_icd10 = {} # # # with open("2015_I9gem.txt", "r") as file: # # # for line in file: # # # 
parts = line.strip().split() # # # if len(parts) == 3: # # # icd9, icd10, _ = parts # # # icd9_to_icd10[icd9] = icd10 # # # # ICD-9 code columns used during training # # # icd9_columns = [ # # # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9', # # # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61', # # # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0', # # # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0', # # # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15', # # # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61' # # # ] # # # # Function for making predictions and mapping to ICD-10 # # # def predict_icd9(texts, tokenizer, model, threshold=0.5): # # # inputs = tokenizer( # # # texts, # # # padding="max_length", # # # truncation=True, # # # max_length=512, # # # return_tensors="pt" # # # ) # # # with torch.no_grad(): # # # outputs = model( # # # input_ids=inputs["input_ids"], # # # attention_mask=inputs["attention_mask"] # # # ) # # # logits = outputs.logits # # # probabilities = torch.sigmoid(logits) # # # predictions = (probabilities > threshold).int() # # # predicted_icd9 = [] # # # for pred in predictions: # # # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1] # # # predicted_icd9.append(codes) # # # # Fetch descriptions and map to ICD-10 codes # # # predictions_with_desc = [] # # # for codes in predicted_icd9: # # # code_with_desc = [] # # # for code in codes: # # # icd9_stripped = code.replace('.', '') # # # icd10_code = icd9_to_icd10.get(icd9_stripped, "Mapping not found") # # # icd9_desc = icd9_descriptions.get(icd9_stripped, "Description not found") # # # code_with_desc.append((code, icd9_desc, icd10_code)) # # # predictions_with_desc.append(code_with_desc) # # # return predictions_with_desc # # # # Streamlit UI # # # st.title("ICD-9 to ICD-10 Code Prediction") # # # 
st.sidebar.header("Model Options") # # # threshold = st.sidebar.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01) # # # st.write("### Enter Medical Summary") # # # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...") # # # if st.button("Predict"): # # # if input_text.strip(): # # # predictions = predict_icd9([input_text], tokenizer, model, threshold) # # # st.write("### Predicted ICD-9 and ICD-10 Codes with Descriptions") # # # for icd9_code, description, icd10_code in predictions[0]: # # # st.write(f"- ICD-9: {icd9_code} ({description}) -> ICD-10: {icd10_code}") # # # else: # # # st.error("Please enter a medical summary.") # # # import os # # # import torch # # # import pandas as pd # # # import streamlit as st # # # from PIL import Image # # # from transformers import LongformerTokenizer, LongformerForSequenceClassification # # # from phi.agent import Agent # # # from phi.model.google import Gemini # # # from phi.tools.duckduckgo import DuckDuckGo # # # # Load the fine-tuned ICD-9 model and tokenizer # # # model_path = "./clinical_longformer" # # # tokenizer = LongformerTokenizer.from_pretrained(model_path) # # # model = LongformerForSequenceClassification.from_pretrained(model_path) # # # model.eval() # Set the model to evaluation mode # # # # Load the ICD-9 descriptions from CSV into a dictionary # # # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file # # # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching # # # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching # # # # ICD-9 code columns used during training # # # icd9_columns = [ # # # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9', # # # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61', # # # '39.95', '401.9', 
'403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0', # # # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0', # # # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15', # # # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61' # # # ] # # # # Function for making ICD-9 predictions # # # def predict_icd9(texts, tokenizer, model, threshold=0.5): # # # inputs = tokenizer( # # # texts, # # # padding="max_length", # # # truncation=True, # # # max_length=512, # # # return_tensors="pt" # # # ) # # # with torch.no_grad(): # # # outputs = model( # # # input_ids=inputs["input_ids"], # # # attention_mask=inputs["attention_mask"] # # # ) # # # logits = outputs.logits # # # probabilities = torch.sigmoid(logits) # # # predictions = (probabilities > threshold).int() # # # predicted_icd9 = [] # # # for pred in predictions: # # # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1] # # # predicted_icd9.append(codes) # # # predictions_with_desc = [] # # # for codes in predicted_icd9: # # # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes] # # # predictions_with_desc.append(code_with_desc) # # # return predictions_with_desc # # # Streamlit UI # # # st.title("Medical Diagnosis Assistant") # # # option = st.selectbox( # # # "Choose Diagnosis Method", # # # ("ICD-9 Code Prediction", "Medical Image Analysis") # # # ) # # # # ICD-9 Code Prediction # # # if option == "ICD-9 Code Prediction": # # # st.write("### Enter Medical Summary") # # # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...") # # # threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01) # # # if st.button("Predict ICD-9 Codes"): # # # if input_text.strip(): # # # predictions = predict_icd9([input_text], tokenizer, model, threshold) # # # st.write("### Predicted ICD-9 Codes and Descriptions") # # # for code, description in 
predictions[0]: # # # st.write(f"- {code}: {description}") # # # else: # # # st.error("Please enter a medical summary.") # # # Medical Image Analysis # # # elif option == "Medical Image Analysis": # # # if "GOOGLE_API_KEY" not in st.session_state: # # # st.warning("Please enter your Google API Key in the sidebar to continue") # # # else: # # # medical_agent = Agent( # # # model=Gemini( # # # api_key=st.session_state.GOOGLE_API_KEY, # # # id="gemini-2.0-flash-exp" # # # ), # # # tools=[DuckDuckGo()], # # # markdown=True # # # ) # # # query = """ # # # You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging. Analyze the patient's medical image and structure your response as follows: # # # ### 1. Image Type & Region # # # - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.) # # # - Identify the patient's anatomical region and positioning # # # - Comment on image quality and technical adequacy # # # ### 2. Key Findings # # # - List primary observations systematically # # # - Note any abnormalities in the patient's imaging with precise descriptions # # # - Include measurements and densities where relevant # # # - Describe location, size, shape, and characteristics # # # - Rate severity: Normal/Mild/Moderate/Severe # # # ### 3. Diagnostic Assessment # # # - Provide primary diagnosis with confidence level # # # - List differential diagnoses in order of likelihood # # # - Support each diagnosis with observed evidence from the patient's imaging # # # - Note any critical or urgent findings # # # ### 4. Patient-Friendly Explanation # # # - Explain the findings in simple, clear language that the patient can understand # # # - Avoid medical jargon or provide clear definitions # # # - Include visual analogies if helpful # # # - Address common patient concerns related to these findings # # # ### 5. 
Research Context # # # - Use the DuckDuckGo search tool to find recent medical literature about similar cases # # # - Provide a list of relevant medical links # # # - Include key references to support your analysis # # # """ # # # upload_container = st.container() # # # image_container = st.container() # # # analysis_container = st.container() # # # with upload_container: # # # uploaded_file = st.file_uploader( # # # "Upload Medical Image", # # # type=["jpg", "jpeg", "png", "dicom"], # # # help="Supported formats: JPG, JPEG, PNG, DICOM" # # # ) # # # if uploaded_file is not None: # # # with image_container: # # # col1, col2, col3 = st.columns([1, 2, 1]) # # # with col2: # # # image = Image.open(uploaded_file) # # # width, height = image.size # # # aspect_ratio = width / height # # # new_width = 500 # # # new_height = int(new_width / aspect_ratio) # # # resized_image = image.resize((new_width, new_height)) # # # st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True) # # # analyze_button = st.button("🔍 Analyze Image") # # # with analysis_container: # # # if analyze_button: # # # image_path = "temp_medical_image.png" # # # with open(image_path, "wb") as f: # # # f.write(uploaded_file.getbuffer()) # # # with st.spinner("🔄 Analyzing image... 
Please wait."): # # # try: # # # response = medical_agent.run(query, images=[image_path]) # # # st.markdown("### 📋 Analysis Results") # # # st.markdown(response.content) # # # except Exception as e: # # # st.error(f"Analysis error: {e}") # # # finally: # # # if os.path.exists(image_path): # # # os.remove(image_path) # # # else: # # # st.info("👆 Please upload a medical image to begin analysis") # # import os # # import torch # # import pandas as pd # # import streamlit as st # # from PIL import Image # # from transformers import LongformerTokenizer, LongformerForSequenceClassification # # from phi.agent import Agent # # from phi.model.google import Gemini # # from phi.tools.duckduckgo import DuckDuckGo # # # Sidebar for Google API Key input # # st.sidebar.title("Settings") # # st.sidebar.write("Enter your Google API Key below for the Medical Image Analysis feature.") # # api_key = st.sidebar.text_input("Google API Key", type="password") # # if api_key: # # st.session_state["GOOGLE_API_KEY"] = api_key # # else: # # st.session_state.pop("GOOGLE_API_KEY", None) # # # Load the fine-tuned ICD-9 model and tokenizer # # model_path = "./clinical_longformer" # # tokenizer = LongformerTokenizer.from_pretrained(model_path) # # model = LongformerForSequenceClassification.from_pretrained(model_path) # # model.eval() # Set the model to evaluation mode # # # Load the ICD-9 descriptions from CSV into a dictionary # # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # Adjust the path to your CSV file # # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching # # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching # # # ICD-9 code columns used during training # # icd9_columns = [ # # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9', # # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', 
'38.93', '39.61', # # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0', # # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0', # # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15', # # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61' # # ] # # # Function for making ICD-9 predictions # # def predict_icd9(texts, tokenizer, model, threshold=0.5): # # inputs = tokenizer( # # texts, # # padding="max_length", # # truncation=True, # # max_length=512, # # return_tensors="pt" # # ) # # with torch.no_grad(): # # outputs = model( # # input_ids=inputs["input_ids"], # # attention_mask=inputs["attention_mask"] # # ) # # logits = outputs.logits # # probabilities = torch.sigmoid(logits) # # predictions = (probabilities > threshold).int() # # predicted_icd9 = [] # # for pred in predictions: # # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1] # # predicted_icd9.append(codes) # # predictions_with_desc = [] # # for codes in predicted_icd9: # # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes] # # predictions_with_desc.append(code_with_desc) # # return predictions_with_desc # # # Streamlit UI # # st.title("Medical Diagnosis Assistant") # # option = st.selectbox( # # "Choose Diagnosis Method", # # ("ICD-9 Code Prediction", "Medical Image Analysis") # # ) # # # ICD-9 Code Prediction # # if option == "ICD-9 Code Prediction": # # st.write("### Enter Medical Summary") # # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...") # # threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01) # # if st.button("Predict ICD-9 Codes"): # # if input_text.strip(): # # predictions = predict_icd9([input_text], tokenizer, model, threshold) # # st.write("### Predicted ICD-9 Codes and Descriptions") # # for code, description in predictions[0]: # # st.write(f"- {code}: {description}") # # 
else: # # st.error("Please enter a medical summary.") # # # Medical Image Analysis # # elif option == "Medical Image Analysis": # # if "GOOGLE_API_KEY" not in st.session_state: # # st.warning("Please enter your Google API Key in the sidebar to continue") # # else: # # medical_agent = Agent( # # model=Gemini( # # api_key=st.session_state["GOOGLE_API_KEY"], # # id="gemini-2.0-flash-exp" # # ), # # tools=[DuckDuckGo()], # # markdown=True # # ) # # query = """ # # You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging. Analyze the patient's medical image and structure your response as follows: # # ### 1. Image Type & Region # # - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.) # # - Identify the patient's anatomical region and positioning # # - Comment on image quality and technical adequacy # # ### 2. Key Findings # # - List primary observations systematically # # - Note any abnormalities in the patient's imaging with precise descriptions # # - Include measurements and densities where relevant # # - Describe location, size, shape, and characteristics # # - Rate severity: Normal/Mild/Moderate/Severe # # ### 3. Diagnostic Assessment # # - Provide primary diagnosis with confidence level # # - List differential diagnoses in order of likelihood # # - Support each diagnosis with observed evidence from the patient's imaging # # - Note any critical or urgent findings # # ### 4. Patient-Friendly Explanation # # - Explain the findings in simple, clear language that the patient can understand # # - Avoid medical jargon or provide clear definitions # # - Include visual analogies if helpful # # - Address common patient concerns related to these findings # # ### 5. 
Research Context # # - Use the DuckDuckGo search tool to find recent medical literature about similar cases # # - Provide a list of relevant medical links # # - Include key references to support your analysis # # """ # # upload_container = st.container() # # image_container = st.container() # # analysis_container = st.container() # # with upload_container: # # uploaded_file = st.file_uploader( # # "Upload Medical Image", # # type=["jpg", "jpeg", "png", "dicom"], # # help="Supported formats: JPG, JPEG, PNG, DICOM" # # ) # # if uploaded_file is not None: # # with image_container: # # col1, col2, col3 = st.columns([1, 2, 1]) # # with col2: # # image = Image.open(uploaded_file) # # width, height = image.size # # aspect_ratio = width / height # # new_width = 500 # # new_height = int(new_width / aspect_ratio) # # resized_image = image.resize((new_width, new_height)) # # st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True) # # analyze_button = st.button("🔍 Analyze Image") # # with analysis_container: # # if analyze_button: # # image_path = "temp_medical_image.png" # # with open(image_path, "wb") as f: # # f.write(uploaded_file.getbuffer()) # # with st.spinner("🔄 Analyzing image... 
Please wait."): # # try: # # response = medical_agent.run(query, images=[image_path]) # # st.markdown("### 📋 Analysis Results") # # st.markdown(response.content) # # except Exception as e: # # st.error(f"Analysis error: {e}") # # finally: # # if os.path.exists(image_path): # # os.remove(image_path) # # else: # # st.info("👆 Please upload a medical image to begin analysis") # import os # import torch # import pandas as pd # import streamlit as st # from PIL import Image # from transformers import LongformerTokenizer, LongformerForSequenceClassification # from phi.agent import Agent # from phi.model.google import Gemini # from phi.tools.duckduckgo import DuckDuckGo # # Load the fine-tuned ICD-9 model and tokenizer # model_path = "./clinical_longformer" # tokenizer = LongformerTokenizer.from_pretrained(model_path) # model = LongformerForSequenceClassification.from_pretrained(model_path) # model.eval() # Set the model to evaluation mode # # Load the ICD-9 descriptions from CSV into a dictionary # icd9_desc_df = pd.read_csv("D_ICD_DIAGNOSES.csv") # icd9_desc_df['ICD9_CODE'] = icd9_desc_df['ICD9_CODE'].astype(str) # Ensure ICD9_CODE is string type for matching # icd9_descriptions = dict(zip(icd9_desc_df['ICD9_CODE'].str.replace('.', ''), icd9_desc_df['LONG_TITLE'])) # Remove decimals in ICD9 code for matching # # ICD-9 code columns used during training # icd9_columns = [ # '038.9', '244.9', '250.00', '272.0', '272.4', '276.1', '276.2', '285.1', '285.9', # '287.5', '305.1', '311', '36.15', '37.22', '37.23', '38.91', '38.93', '39.61', # '39.95', '401.9', '403.90', '410.71', '412', '414.01', '424.0', '427.31', '428.0', # '486', '496', '507.0', '511.9', '518.81', '530.81', '584.9', '585.9', '599.0', # '88.56', '88.72', '93.90', '96.04', '96.6', '96.71', '96.72', '99.04', '99.15', # '995.92', 'V15.82', 'V45.81', 'V45.82', 'V58.61' # ] # # Function for making ICD-9 predictions # def predict_icd9(texts, tokenizer, model, threshold=0.5): # inputs = tokenizer( # texts, # 
padding="max_length", # truncation=True, # max_length=512, # return_tensors="pt" # ) # with torch.no_grad(): # outputs = model( # input_ids=inputs["input_ids"], # attention_mask=inputs["attention_mask"] # ) # logits = outputs.logits # probabilities = torch.sigmoid(logits) # predictions = (probabilities > threshold).int() # predicted_icd9 = [] # for pred in predictions: # codes = [icd9_columns[i] for i, val in enumerate(pred) if val == 1] # predicted_icd9.append(codes) # predictions_with_desc = [] # for codes in predicted_icd9: # code_with_desc = [(code, icd9_descriptions.get(code.replace('.', ''), "Description not found")) for code in codes] # predictions_with_desc.append(code_with_desc) # return predictions_with_desc # # Read the API key from the environment; never hard-code it (a previously committed key was redacted here and must be revoked) # GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"] # # Streamlit UI # st.title("Medical Diagnosis Assistant") # option = st.selectbox( # "Choose Diagnosis Method", # ("ICD-9 Code Prediction", "Medical Image Analysis") # ) # # ICD-9 Code Prediction # if option == "ICD-9 Code Prediction": # st.write("### Enter Medical Summary") # input_text = st.text_area("Medical Summary", placeholder="Enter clinical notes here...") # threshold = st.slider("Prediction Threshold", 0.0, 1.0, 0.5, 0.01) # if st.button("Predict ICD-9 Codes"): # if input_text.strip(): # predictions = predict_icd9([input_text], tokenizer, model, threshold) # st.write("### Predicted ICD-9 Codes and Descriptions") # for code, description in predictions[0]: # st.write(f"- {code}: {description}") # else: # st.error("Please enter a medical summary.") # # Medical Image Analysis # elif option == "Medical Image Analysis": # medical_agent = Agent( # model=Gemini( # api_key=GOOGLE_API_KEY, # id="gemini-2.0-flash-exp" # ), # tools=[DuckDuckGo()], # markdown=True # ) # query = """ # You are a highly skilled medical imaging expert with extensive knowledge in radiology and diagnostic imaging.
Analyze the patient's medical image and structure your response as follows: # ### 1. Image Type & Region # - Specify imaging modality (X-ray/MRI/CT/Ultrasound/etc.) # - Identify the patient's anatomical region and positioning # - Comment on image quality and technical adequacy # ### 2. Key Findings # - List primary observations systematically # - Note any abnormalities in the patient's imaging with precise descriptions # - Include measurements and densities where relevant # - Describe location, size, shape, and characteristics # - Rate severity: Normal/Mild/Moderate/Severe # ### 3. Diagnostic Assessment # - Provide primary diagnosis with confidence level # - List differential diagnoses in order of likelihood # - Support each diagnosis with observed evidence from the patient's imaging # - Note any critical or urgent findings # ### 4. Patient-Friendly Explanation # - Explain the findings in simple, clear language that the patient can understand # - Avoid medical jargon or provide clear definitions # - Include visual analogies if helpful # - Address common patient concerns related to these findings # ### 5. 
Research Context # - Use the DuckDuckGo search tool to find recent medical literature about similar cases # - Provide a list of relevant medical links # - Include key references to support your analysis # """ # upload_container = st.container() # image_container = st.container() # analysis_container = st.container() # with upload_container: # uploaded_file = st.file_uploader( # "Upload Medical Image", # type=["jpg", "jpeg", "png", "dicom"], # help="Supported formats: JPG, JPEG, PNG, DICOM" # ) # if uploaded_file is not None: # with image_container: # col1, col2, col3 = st.columns([1, 2, 1]) # with col2: # image = Image.open(uploaded_file) # width, height = image.size # aspect_ratio = width / height # new_width = 500 # new_height = int(new_width / aspect_ratio) # resized_image = image.resize((new_width, new_height)) # st.image(resized_image, caption="Uploaded Medical Image", use_container_width=True) # analyze_button = st.button("🔍 Analyze Image") # with analysis_container: # if analyze_button: # image_path = "temp_medical_image.png" # with open(image_path, "wb") as f: # f.write(uploaded_file.getbuffer()) # with st.spinner("🔄 Analyzing image... Please wait."): # try: # response = medical_agent.run(query, images=[image_path]) # st.markdown("### 📋 Analysis Results") # st.markdown(response.content) # except Exception as e: # st.error(f"Analysis error: {e}") # finally: # if os.path.exists(image_path): # os.remove(image_path) # else: # st.info("👆 Please upload a medical image to begin analysis")