"""Gradio app that extracts structured information from clinical-note PDFs.

An uploaded PDF is sent to an OpenAI assistant with the ``file_search`` tool,
the XML answer is parsed into a pandas DataFrame, and the result is rendered
as an HTML table.
"""

import glob
import json
import os
import re
import time
import xml.etree.ElementTree as ET

import gradio as gr
import pandas as pd
from openai import OpenAI

import api_keys
import note_extraction.hf_hosting.prompts as prompts

client = OpenAI(api_key=api_keys.OPENAI_API_KEY)
model_name = "gpt-4o-2024-08-06"

# NOTE: the assistant is created at import time (one API call per import).
demo = client.beta.assistants.create(
    name="Information Extractor",
    instructions="Extract information from this note.",
    model=model_name,
    tools=[{"type": "file_search"}],
)

# Top-level tags parsed into a single record (reasoning + every other child).
_SINGLE_RECORD_TAGS = frozenset({
    'patient_name', 'date_of_birth', 'sex', 'weight', 'date_of_death',
    'surgery', 'surgery_outcome', 'metastasis_at_time_of_diagnosis',
})
# Top-level tags parsed into a list with one entry per item child.
_ITEM_LIST_TAGS = frozenset({
    'traditional_chemo', 'other_cancer_treatments', 'other_conmeds',
})
# Child tags that count as an item inside an _ITEM_LIST_TAGS section.
_ITEM_TAGS = frozenset({'drug', 'treatment', 'medication'})


def _stripped_text(element):
    """Return the element's stripped text, or None if it is missing or empty."""
    if element is None or not element.text:
        return None
    return element.text.strip()


def _record_fields(element):
    """Return {'reasoning': ..., <child tag>: <text>, ...} for one element."""
    record = {'reasoning': _stripped_text(element.find('reasoning'))}
    for child in element:
        if child.tag != 'reasoning':
            record[child.tag] = _stripped_text(child)
    return record


def parse_xml_response(xml_string: str) -> pd.DataFrame:
    """Parse the model's XML response into a one-row DataFrame.

    Args:
        xml_string: Raw model output; text before the first ``<`` and after
            the last ``>`` is discarded before parsing.

    Returns:
        A DataFrame with a single row and three-level MultiIndex columns
        ``(category, index, field)``, where ``index`` is a 1-based string.
        An empty DataFrame is returned when no XML is found, the XML cannot
        be parsed, or no recognised element is present.
    """
    # Keep only the span from the first '<' to the last '>' so surrounding
    # text (e.g. a markdown code fence) does not break the parser.
    start = xml_string.find('<')
    end = xml_string.rfind('>')
    if start == -1 or end < start:
        print("No valid XML content found.")
        return pd.DataFrame()

    try:
        root = ET.fromstring(xml_string[start:end + 1])
    except ET.ParseError as e:
        print(f"Error parsing XML: {e}")
        return pd.DataFrame()

    result = {}
    for element in root:
        tag = element.tag
        if tag in _SINGLE_RECORD_TAGS:
            result[tag] = _record_fields(element)
        elif tag in _ITEM_LIST_TAGS:
            entries = result.setdefault(tag, [])
            reasoning = _stripped_text(element.find('reasoning'))
            # NOTE(review): every item in the section receives the section's
            # first <date>; confirm against the prompt's XML schema whether
            # dates are meant to be per item.
            date = _stripped_text(element.find('date'))
            for item in element:
                if item.tag in _ITEM_TAGS:
                    entries.append({
                        'reasoning': reasoning,
                        'name': _stripped_text(item),
                        'date': date,
                    })
        elif tag == 'compounding_pharmacy':
            result[tag] = {
                'reasoning': _stripped_text(element.find('reasoning')),
                'pharmacy': _stripped_text(element.find('pharmacy')),
            }
        elif tag == 'adverse_effects':
            # The tag may repeat; each occurrence becomes one entry.
            result.setdefault(tag, []).append(_record_fields(element))

    # Flatten into {(category, index, field): [value]} for a one-row frame.
    df_data = {}
    for key, value in result.items():
        records = [value] if isinstance(value, dict) else value
        for position, record in enumerate(records, start=1):
            for field, field_value in record.items():
                df_data[(key, str(position), field)] = [field_value]

    # MultiIndex.from_tuples raises TypeError on an empty sequence, so bail
    # out early when nothing was recognised.
    if not df_data:
        return pd.DataFrame()

    df = pd.DataFrame(df_data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df


def get_response(prompt, file_id, assistant_id):
    """Run the assistant on an uploaded file and return its text answer.

    Args:
        prompt: User message sent to the assistant.
        file_id: ID of a file already uploaded to OpenAI; attached to the
            message with the ``file_search`` tool.
        assistant_id: ID of the assistant that executes the run.

    Returns:
        The text of the single message produced by the run, with the text of
        every annotation removed.

    Raises:
        RuntimeError: If the run did not produce exactly one message.
    """
    thread = client.beta.threads.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
                "attachments": [
                    {"file_id": file_id, "tools": [{"type": "file_search"}]}
                ],
            }
        ]
    )

    run = client.beta.threads.runs.create_and_poll(
        thread_id=thread.id, assistant_id=assistant_id
    )

    messages = list(
        client.beta.threads.messages.list(thread_id=thread.id, run_id=run.id)
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(messages) != 1:
        raise RuntimeError(
            f"Expected exactly one message from the run, got {len(messages)} "
            f"(run status: {run.status})"
        )

    message_content = messages[0].content[0].text
    text = message_content.value
    for annotation in message_content.annotations:
        text = text.replace(annotation.text, "")
    return text


def process(file_content):
    """Extract information from an uploaded PDF and render it as HTML.

    Args:
        file_content: Raw bytes of the uploaded PDF.

    Returns:
        An HTML string: a table with the columns Category, Index, Field and
        Value, or a short message when there is nothing to show.
    """
    if not file_content:
        return "<p>Please upload a PDF file.</p>"

    os.makedirs("cache", exist_ok=True)
    file_name = f"cache/{time.time()}.pdf"
    with open(file_name, "wb") as f:
        f.write(file_content)

    # Open the upload in a context manager so the handle is always closed.
    with open(file_name, "rb") as pdf_file:
        message_file = client.files.create(file=pdf_file, purpose="assistants")

    response = get_response(prompts.info_prompt, message_file.id, demo.id)
    df = parse_xml_response(response)

    if df.empty:
        return "<p>No valid information could be extracted from the provided file.</p>"

    # Transpose so each (category, index, field) becomes one table row.
    table = df.T.reset_index()
    table.columns = ['Category', 'Index', 'Field', 'Value']
    # 'Index' holds digit strings; sort numerically so "10" follows "9".
    table = (
        table.assign(_order=table['Index'].astype(int))
        .sort_values(['Category', '_order', 'Field'])
        .drop(columns='_order')
    )

    # Values come from model output over an uploaded document, so they are
    # HTML-escaped rather than rendered raw.
    table_html = table.to_html(
        index=False,
        classes='table table-striped table-bordered',
        escape=True,
    )

    # Wrapper template for the table; no custom CSS is injected here.
    return f"""
    {table_html}
    """


def gradio_interface():
    """Build the Gradio UI around `process` and launch it."""
    upload_component = gr.File(label="Upload PDF", type="binary")
    output_component = gr.HTML(label="Extracted Information")

    # Named `interface` so it does not shadow the module-level assistant.
    interface = gr.Interface(
        fn=process,
        inputs=upload_component,
        outputs=output_component,
        title="Clinical Note Information Extractor",
        description="This tool extracts key information from clinical notes in PDF format.",
    )

    interface.queue()
    interface.launch()


if __name__ == "__main__":
    gradio_interface()